mirror of
https://github.com/django/daphne.git
synced 2024-11-25 09:13:44 +03:00
Merge pull request #57 from raphaelm/issue55
Fix #55 -- Optionally parse X-Forwarded-For header
This commit is contained in:
commit
dc98b09dfd
|
@ -96,6 +96,14 @@ class CommandLineInterface(object):
|
||||||
help='The setting for the ASGI root_path variable',
|
help='The setting for the ASGI root_path variable',
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--proxy-headers',
|
||||||
|
dest='proxy_headers',
|
||||||
|
help='Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the '
|
||||||
|
'client address',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def entrypoint(cls):
|
def entrypoint(cls):
|
||||||
|
@ -154,4 +162,6 @@ class CommandLineInterface(object):
|
||||||
ws_protocols=args.ws_protocols,
|
ws_protocols=args.ws_protocols,
|
||||||
root_path=args.root_path,
|
root_path=args.root_path,
|
||||||
verbosity=args.verbosity,
|
verbosity=args.verbosity,
|
||||||
|
proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None,
|
||||||
|
proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None,
|
||||||
).run()
|
).run()
|
||||||
|
|
|
@ -9,6 +9,7 @@ from six.moves.urllib_parse import unquote, unquote_plus
|
||||||
from twisted.protocols.policies import ProtocolWrapper
|
from twisted.protocols.policies import ProtocolWrapper
|
||||||
from twisted.web import http
|
from twisted.web import http
|
||||||
|
|
||||||
|
from .utils import parse_x_forwarded_for
|
||||||
from .ws_protocol import WebSocketProtocol, WebSocketFactory
|
from .ws_protocol import WebSocketProtocol, WebSocketFactory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -67,6 +68,15 @@ class WebRequest(http.Request):
|
||||||
else:
|
else:
|
||||||
self.client_addr = None
|
self.client_addr = None
|
||||||
self.server_addr = None
|
self.server_addr = None
|
||||||
|
|
||||||
|
if self.factory.proxy_forwarded_address_header:
|
||||||
|
self.client_addr = parse_x_forwarded_for(
|
||||||
|
self.requestHeaders,
|
||||||
|
self.factory.proxy_forwarded_address_header,
|
||||||
|
self.factory.proxy_forwarded_port_header,
|
||||||
|
self.client_addr
|
||||||
|
)
|
||||||
|
|
||||||
# Check for unicodeish path (or it'll crash when trying to parse)
|
# Check for unicodeish path (or it'll crash when trying to parse)
|
||||||
try:
|
try:
|
||||||
self.path.decode("ascii")
|
self.path.decode("ascii")
|
||||||
|
@ -281,7 +291,7 @@ class HTTPFactory(http.HTTPFactory):
|
||||||
|
|
||||||
protocol = HTTPProtocol
|
protocol = HTTPProtocol
|
||||||
|
|
||||||
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30):
|
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30, proxy_forwarded_address_header=None, proxy_forwarded_port_header=None):
|
||||||
http.HTTPFactory.__init__(self)
|
http.HTTPFactory.__init__(self)
|
||||||
self.channel_layer = channel_layer
|
self.channel_layer = channel_layer
|
||||||
self.action_logger = action_logger
|
self.action_logger = action_logger
|
||||||
|
@ -289,6 +299,8 @@ class HTTPFactory(http.HTTPFactory):
|
||||||
self.websocket_timeout = websocket_timeout
|
self.websocket_timeout = websocket_timeout
|
||||||
self.websocket_connect_timeout = websocket_connect_timeout
|
self.websocket_connect_timeout = websocket_connect_timeout
|
||||||
self.ping_interval = ping_interval
|
self.ping_interval = ping_interval
|
||||||
|
self.proxy_forwarded_address_header = proxy_forwarded_address_header
|
||||||
|
self.proxy_forwarded_port_header = proxy_forwarded_port_header
|
||||||
# We track all sub-protocols for response channel mapping
|
# We track all sub-protocols for response channel mapping
|
||||||
self.reply_protocols = {}
|
self.reply_protocols = {}
|
||||||
# Make a factory for WebSocket protocols
|
# Make a factory for WebSocket protocols
|
||||||
|
|
|
@ -26,6 +26,8 @@ class Server(object):
|
||||||
ping_timeout=30,
|
ping_timeout=30,
|
||||||
ws_protocols=None,
|
ws_protocols=None,
|
||||||
root_path="",
|
root_path="",
|
||||||
|
proxy_forwarded_address_header=None,
|
||||||
|
proxy_forwarded_port_header=None,
|
||||||
verbosity=1
|
verbosity=1
|
||||||
):
|
):
|
||||||
self.channel_layer = channel_layer
|
self.channel_layer = channel_layer
|
||||||
|
@ -38,6 +40,8 @@ class Server(object):
|
||||||
self.http_timeout = http_timeout
|
self.http_timeout = http_timeout
|
||||||
self.ping_interval = ping_interval
|
self.ping_interval = ping_interval
|
||||||
self.ping_timeout = ping_timeout
|
self.ping_timeout = ping_timeout
|
||||||
|
self.proxy_forwarded_address_header = proxy_forwarded_address_header
|
||||||
|
self.proxy_forwarded_port_header = proxy_forwarded_port_header
|
||||||
# If they did not provide a websocket timeout, default it to the
|
# If they did not provide a websocket timeout, default it to the
|
||||||
# channel layer's group_expiry value if present, or one day if not.
|
# channel layer's group_expiry value if present, or one day if not.
|
||||||
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
|
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
|
||||||
|
@ -55,6 +59,8 @@ class Server(object):
|
||||||
ping_timeout=self.ping_timeout,
|
ping_timeout=self.ping_timeout,
|
||||||
ws_protocols=self.ws_protocols,
|
ws_protocols=self.ws_protocols,
|
||||||
root_path=self.root_path,
|
root_path=self.root_path,
|
||||||
|
proxy_forwarded_address_header=self.proxy_forwarded_address_header,
|
||||||
|
proxy_forwarded_port_header=self.proxy_forwarded_port_header
|
||||||
)
|
)
|
||||||
if self.verbosity <= 1:
|
if self.verbosity <= 1:
|
||||||
# Redirect the Twisted log to nowhere
|
# Redirect the Twisted log to nowhere
|
||||||
|
|
|
@ -93,3 +93,51 @@ class TestHTTPProtocol(TestCase):
|
||||||
# Get the disconnection notification
|
# Get the disconnection notification
|
||||||
_, disconnect_message = self.channel_layer.receive_many(["http.disconnect"])
|
_, disconnect_message = self.channel_layer.receive_many(["http.disconnect"])
|
||||||
self.assertEqual(disconnect_message['path'], "/te st-à/")
|
self.assertEqual(disconnect_message['path'], "/te st-à/")
|
||||||
|
|
||||||
|
def test_x_forwarded_for_ignored(self):
|
||||||
|
"""
|
||||||
|
Tests basic HTTP parsing
|
||||||
|
"""
|
||||||
|
self.proto.dataReceived(
|
||||||
|
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
|
||||||
|
b"Host: somewhere.com\r\n" +
|
||||||
|
b"X-Forwarded-For: 10.1.2.3\r\n" +
|
||||||
|
b"X-Forwarded-Port: 80\r\n" +
|
||||||
|
b"\r\n"
|
||||||
|
)
|
||||||
|
# Get the resulting message off of the channel layer
|
||||||
|
_, message = self.channel_layer.receive_many(["http.request"])
|
||||||
|
self.assertEqual(message['client'], ['192.168.1.1', 54321])
|
||||||
|
|
||||||
|
def test_x_forwarded_for_parsed(self):
|
||||||
|
"""
|
||||||
|
Tests basic HTTP parsing
|
||||||
|
"""
|
||||||
|
self.factory.proxy_forwarded_address_header = 'X-Forwarded-For'
|
||||||
|
self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port'
|
||||||
|
self.proto.dataReceived(
|
||||||
|
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
|
||||||
|
b"Host: somewhere.com\r\n" +
|
||||||
|
b"X-Forwarded-For: 10.1.2.3\r\n" +
|
||||||
|
b"X-Forwarded-Port: 80\r\n" +
|
||||||
|
b"\r\n"
|
||||||
|
)
|
||||||
|
# Get the resulting message off of the channel layer
|
||||||
|
_, message = self.channel_layer.receive_many(["http.request"])
|
||||||
|
self.assertEqual(message['client'], ['10.1.2.3', 80])
|
||||||
|
|
||||||
|
def test_x_forwarded_for_port_missing(self):
|
||||||
|
"""
|
||||||
|
Tests basic HTTP parsing
|
||||||
|
"""
|
||||||
|
self.factory.proxy_forwarded_address_header = 'X-Forwarded-For'
|
||||||
|
self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port'
|
||||||
|
self.proto.dataReceived(
|
||||||
|
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
|
||||||
|
b"Host: somewhere.com\r\n" +
|
||||||
|
b"X-Forwarded-For: 10.1.2.3\r\n" +
|
||||||
|
b"\r\n"
|
||||||
|
)
|
||||||
|
# Get the resulting message off of the channel layer
|
||||||
|
_, message = self.channel_layer.receive_many(["http.request"])
|
||||||
|
self.assertEqual(message['client'], ['10.1.2.3', 0])
|
||||||
|
|
61
daphne/tests/test_utils.py
Normal file
61
daphne/tests/test_utils.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
|
from ..utils import parse_x_forwarded_for
|
||||||
|
|
||||||
|
|
||||||
|
class TestXForwardedForParsing(TestCase):
|
||||||
|
"""
|
||||||
|
Tests that the parse_x_forwarded_for util correcly parses headers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_basic(self):
|
||||||
|
headers = Headers({
|
||||||
|
b'X-Forwarded-For': [b'10.1.2.3'],
|
||||||
|
b'X-Forwarded-Port': [b'1234']
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers),
|
||||||
|
['10.1.2.3', 1234]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_address_only(self):
|
||||||
|
headers = Headers({
|
||||||
|
b'X-Forwarded-For': [b'10.1.2.3'],
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers),
|
||||||
|
['10.1.2.3', 0]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_port_in_address(self):
|
||||||
|
headers = Headers({
|
||||||
|
b'X-Forwarded-For': [b'10.1.2.3:5123'],
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers),
|
||||||
|
['10.1.2.3', 5123]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multiple_proxys(self):
|
||||||
|
headers = Headers({
|
||||||
|
b'X-Forwarded-For': [b'10.1.2.3, 10.1.2.4'],
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers),
|
||||||
|
['10.1.2.4', 0]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_original(self):
|
||||||
|
headers = Headers({})
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers, original=['127.0.0.1', 80]),
|
||||||
|
['127.0.0.1', 80]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_original(self):
|
||||||
|
headers = Headers({})
|
||||||
|
self.assertIsNone(parse_x_forwarded_for(headers))
|
48
daphne/utils.py
Normal file
48
daphne/utils.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
|
||||||
|
|
||||||
|
def parse_x_forwarded_for(headers,
|
||||||
|
address_header_name='X-Forwarded-For',
|
||||||
|
port_header_name='X-Forwarded-Port',
|
||||||
|
original=None):
|
||||||
|
"""
|
||||||
|
Parses an X-Forwarded-For header and returns a host/port pair as a list.
|
||||||
|
|
||||||
|
@param headers: The twisted-style object containing a request's headers
|
||||||
|
@param address_header_name: The name of the expected host header
|
||||||
|
@param port_header_name: The name of the expected port header
|
||||||
|
@param original: A host/port pair that should be returned if the headers are not in the request
|
||||||
|
@return: A list containing a host (string) as the first entry and a port (int) as the second.
|
||||||
|
"""
|
||||||
|
if not address_header_name:
|
||||||
|
return original
|
||||||
|
|
||||||
|
address_header_name = address_header_name.lower().encode("utf-8")
|
||||||
|
result = original
|
||||||
|
if headers.hasHeader(address_header_name):
|
||||||
|
address_value = headers.getRawHeaders(address_header_name)[0].decode("utf-8")
|
||||||
|
|
||||||
|
if ',' in address_value:
|
||||||
|
address_value = address_value.split(",")[-1].strip()
|
||||||
|
|
||||||
|
if ':' in address_value:
|
||||||
|
address_host, address_port = address_value.split(':')
|
||||||
|
result = [address_host, 0]
|
||||||
|
try:
|
||||||
|
result[1] = int(address_port)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
result = [address_value, 0]
|
||||||
|
|
||||||
|
if port_header_name:
|
||||||
|
# We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For
|
||||||
|
# header to avoid inconsistent results.
|
||||||
|
port_header_name = port_header_name.lower().encode("utf-8")
|
||||||
|
if headers.hasHeader(port_header_name):
|
||||||
|
port_value = headers.getRawHeaders(port_header_name)[0].decode("utf-8")
|
||||||
|
try:
|
||||||
|
result[1] = int(port_value)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return result
|
|
@ -9,6 +9,8 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory, ConnectionDeny
|
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory, ConnectionDeny
|
||||||
|
|
||||||
|
from .utils import parse_x_forwarded_for
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,6 +56,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
else:
|
else:
|
||||||
self.client_addr = None
|
self.client_addr = None
|
||||||
self.server_addr = None
|
self.server_addr = None
|
||||||
|
|
||||||
|
if self.factory.proxy_forwarded_address_header:
|
||||||
|
self.client_addr = parse_x_forwarded_for(
|
||||||
|
self.requestHeaders,
|
||||||
|
self.main_factory.proxy_forwarded_address_header,
|
||||||
|
self.main_factory.proxy_forwarded_port_header,
|
||||||
|
self.client_addr
|
||||||
|
)
|
||||||
|
|
||||||
# Make initial request info dict from request (we only have it here)
|
# Make initial request info dict from request (we only have it here)
|
||||||
self.path = request.path.encode("ascii")
|
self.path = request.path.encode("ascii")
|
||||||
self.request_info = {
|
self.request_info = {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user