mirror of
				https://github.com/django/daphne.git
				synced 2025-10-31 15:57:25 +03:00 
			
		
		
		
	Merge pull request #57 from raphaelm/issue55
Fix #55 -- Optionally parse X-Forwarded-For header
This commit is contained in:
		
						commit
						dc98b09dfd
					
				|  | @ -96,6 +96,14 @@ class CommandLineInterface(object): | |||
|             help='The setting for the ASGI root_path variable', | ||||
|             default="", | ||||
|         ) | ||||
|         self.parser.add_argument( | ||||
|             '--proxy-headers', | ||||
|             dest='proxy_headers', | ||||
|             help='Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the ' | ||||
|                  'client address', | ||||
|             default=False, | ||||
|             action='store_true', | ||||
|         ) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def entrypoint(cls): | ||||
|  | @ -154,4 +162,6 @@ class CommandLineInterface(object): | |||
|             ws_protocols=args.ws_protocols, | ||||
|             root_path=args.root_path, | ||||
|             verbosity=args.verbosity, | ||||
|             proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None, | ||||
|             proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None, | ||||
|         ).run() | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ from six.moves.urllib_parse import unquote, unquote_plus | |||
| from twisted.protocols.policies import ProtocolWrapper | ||||
| from twisted.web import http | ||||
| 
 | ||||
| from .utils import parse_x_forwarded_for | ||||
| from .ws_protocol import WebSocketProtocol, WebSocketFactory | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -67,6 +68,15 @@ class WebRequest(http.Request): | |||
|             else: | ||||
|                 self.client_addr = None | ||||
|                 self.server_addr = None | ||||
| 
 | ||||
|             if self.factory.proxy_forwarded_address_header: | ||||
|                 self.client_addr = parse_x_forwarded_for( | ||||
|                     self.requestHeaders, | ||||
|                     self.factory.proxy_forwarded_address_header, | ||||
|                     self.factory.proxy_forwarded_port_header, | ||||
|                     self.client_addr | ||||
|                 ) | ||||
| 
 | ||||
|             # Check for unicodeish path (or it'll crash when trying to parse) | ||||
|             try: | ||||
|                 self.path.decode("ascii") | ||||
|  | @ -281,7 +291,7 @@ class HTTPFactory(http.HTTPFactory): | |||
| 
 | ||||
|     protocol = HTTPProtocol | ||||
| 
 | ||||
|     def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30): | ||||
|     def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30, proxy_forwarded_address_header=None, proxy_forwarded_port_header=None): | ||||
|         http.HTTPFactory.__init__(self) | ||||
|         self.channel_layer = channel_layer | ||||
|         self.action_logger = action_logger | ||||
|  | @ -289,6 +299,8 @@ class HTTPFactory(http.HTTPFactory): | |||
|         self.websocket_timeout = websocket_timeout | ||||
|         self.websocket_connect_timeout = websocket_connect_timeout | ||||
|         self.ping_interval = ping_interval | ||||
|         self.proxy_forwarded_address_header = proxy_forwarded_address_header | ||||
|         self.proxy_forwarded_port_header = proxy_forwarded_port_header | ||||
|         # We track all sub-protocols for response channel mapping | ||||
|         self.reply_protocols = {} | ||||
|         # Make a factory for WebSocket protocols | ||||
|  |  | |||
|  | @ -26,6 +26,8 @@ class Server(object): | |||
|         ping_timeout=30, | ||||
|         ws_protocols=None, | ||||
|         root_path="", | ||||
|         proxy_forwarded_address_header=None, | ||||
|         proxy_forwarded_port_header=None, | ||||
|         verbosity=1 | ||||
|     ): | ||||
|         self.channel_layer = channel_layer | ||||
|  | @ -38,6 +40,8 @@ class Server(object): | |||
|         self.http_timeout = http_timeout | ||||
|         self.ping_interval = ping_interval | ||||
|         self.ping_timeout = ping_timeout | ||||
|         self.proxy_forwarded_address_header = proxy_forwarded_address_header | ||||
|         self.proxy_forwarded_port_header = proxy_forwarded_port_header | ||||
|         # If they did not provide a websocket timeout, default it to the | ||||
|         # channel layer's group_expiry value if present, or one day if not. | ||||
|         self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) | ||||
|  | @ -55,6 +59,8 @@ class Server(object): | |||
|             ping_timeout=self.ping_timeout, | ||||
|             ws_protocols=self.ws_protocols, | ||||
|             root_path=self.root_path, | ||||
|             proxy_forwarded_address_header=self.proxy_forwarded_address_header, | ||||
|             proxy_forwarded_port_header=self.proxy_forwarded_port_header | ||||
|         ) | ||||
|         if self.verbosity <= 1: | ||||
|             # Redirect the Twisted log to nowhere | ||||
|  |  | |||
|  | @ -93,3 +93,51 @@ class TestHTTPProtocol(TestCase): | |||
|         # Get the disconnection notification | ||||
|         _, disconnect_message = self.channel_layer.receive_many(["http.disconnect"]) | ||||
|         self.assertEqual(disconnect_message['path'], "/te st-à/") | ||||
| 
 | ||||
|     def test_x_forwarded_for_ignored(self): | ||||
|         """ | ||||
|         Tests basic HTTP parsing | ||||
|         """ | ||||
|         self.proto.dataReceived( | ||||
|             b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + | ||||
|             b"Host: somewhere.com\r\n" + | ||||
|             b"X-Forwarded-For: 10.1.2.3\r\n" + | ||||
|             b"X-Forwarded-Port: 80\r\n" + | ||||
|             b"\r\n" | ||||
|         ) | ||||
|         # Get the resulting message off of the channel layer | ||||
|         _, message = self.channel_layer.receive_many(["http.request"]) | ||||
|         self.assertEqual(message['client'], ['192.168.1.1', 54321]) | ||||
| 
 | ||||
|     def test_x_forwarded_for_parsed(self): | ||||
|         """ | ||||
|         Tests basic HTTP parsing | ||||
|         """ | ||||
|         self.factory.proxy_forwarded_address_header = 'X-Forwarded-For' | ||||
|         self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port' | ||||
|         self.proto.dataReceived( | ||||
|             b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + | ||||
|             b"Host: somewhere.com\r\n" + | ||||
|             b"X-Forwarded-For: 10.1.2.3\r\n" + | ||||
|             b"X-Forwarded-Port: 80\r\n" + | ||||
|             b"\r\n" | ||||
|         ) | ||||
|         # Get the resulting message off of the channel layer | ||||
|         _, message = self.channel_layer.receive_many(["http.request"]) | ||||
|         self.assertEqual(message['client'], ['10.1.2.3', 80]) | ||||
| 
 | ||||
|     def test_x_forwarded_for_port_missing(self): | ||||
|         """ | ||||
|         Tests basic HTTP parsing | ||||
|         """ | ||||
|         self.factory.proxy_forwarded_address_header = 'X-Forwarded-For' | ||||
|         self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port' | ||||
|         self.proto.dataReceived( | ||||
|             b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + | ||||
|             b"Host: somewhere.com\r\n" + | ||||
|             b"X-Forwarded-For: 10.1.2.3\r\n" + | ||||
|             b"\r\n" | ||||
|         ) | ||||
|         # Get the resulting message off of the channel layer | ||||
|         _, message = self.channel_layer.receive_many(["http.request"]) | ||||
|         self.assertEqual(message['client'], ['10.1.2.3', 0]) | ||||
|  |  | |||
							
								
								
									
										61
									
								
								daphne/tests/test_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								daphne/tests/test_utils.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,61 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| from unittest import TestCase | ||||
| 
 | ||||
| from twisted.web.http_headers import Headers | ||||
| 
 | ||||
| from ..utils import parse_x_forwarded_for | ||||
| 
 | ||||
| 
 | ||||
| class TestXForwardedForParsing(TestCase): | ||||
|     """ | ||||
|     Tests that the parse_x_forwarded_for util correcly parses headers. | ||||
|     """ | ||||
| 
 | ||||
|     def test_basic(self): | ||||
|         headers = Headers({ | ||||
|             b'X-Forwarded-For': [b'10.1.2.3'], | ||||
|             b'X-Forwarded-Port': [b'1234'] | ||||
|         }) | ||||
|         self.assertEqual( | ||||
|             parse_x_forwarded_for(headers), | ||||
|             ['10.1.2.3', 1234] | ||||
|         ) | ||||
| 
 | ||||
|     def test_address_only(self): | ||||
|         headers = Headers({ | ||||
|             b'X-Forwarded-For': [b'10.1.2.3'], | ||||
|         }) | ||||
|         self.assertEqual( | ||||
|             parse_x_forwarded_for(headers), | ||||
|             ['10.1.2.3', 0] | ||||
|         ) | ||||
| 
 | ||||
|     def test_port_in_address(self): | ||||
|         headers = Headers({ | ||||
|             b'X-Forwarded-For': [b'10.1.2.3:5123'], | ||||
|         }) | ||||
|         self.assertEqual( | ||||
|             parse_x_forwarded_for(headers), | ||||
|             ['10.1.2.3', 5123] | ||||
|         ) | ||||
| 
 | ||||
|     def test_multiple_proxys(self): | ||||
|         headers = Headers({ | ||||
|             b'X-Forwarded-For': [b'10.1.2.3, 10.1.2.4'], | ||||
|         }) | ||||
|         self.assertEqual( | ||||
|             parse_x_forwarded_for(headers), | ||||
|             ['10.1.2.4', 0] | ||||
|         ) | ||||
| 
 | ||||
|     def test_original(self): | ||||
|         headers = Headers({}) | ||||
|         self.assertEqual( | ||||
|             parse_x_forwarded_for(headers, original=['127.0.0.1', 80]), | ||||
|             ['127.0.0.1', 80] | ||||
|         ) | ||||
| 
 | ||||
|     def test_no_original(self): | ||||
|         headers = Headers({}) | ||||
|         self.assertIsNone(parse_x_forwarded_for(headers)) | ||||
							
								
								
									
										48
									
								
								daphne/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								daphne/utils.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,48 @@ | |||
| 
 | ||||
| 
 | ||||
| def parse_x_forwarded_for(headers, | ||||
|                           address_header_name='X-Forwarded-For', | ||||
|                           port_header_name='X-Forwarded-Port', | ||||
|                           original=None): | ||||
|     """ | ||||
|     Parses an X-Forwarded-For header and returns a host/port pair as a list. | ||||
| 
 | ||||
|     @param headers: The twisted-style object containing a request's headers | ||||
|     @param address_header_name: The name of the expected host header | ||||
|     @param port_header_name: The name of the expected port header | ||||
|     @param original: A host/port pair that should be returned if the headers are not in the request | ||||
|     @return: A list containing a host (string) as the first entry and a port (int) as the second. | ||||
|     """ | ||||
|     if not address_header_name: | ||||
|         return original | ||||
| 
 | ||||
|     address_header_name = address_header_name.lower().encode("utf-8") | ||||
|     result = original | ||||
|     if headers.hasHeader(address_header_name): | ||||
|         address_value = headers.getRawHeaders(address_header_name)[0].decode("utf-8") | ||||
| 
 | ||||
|         if ',' in address_value: | ||||
|             address_value = address_value.split(",")[-1].strip() | ||||
| 
 | ||||
|         if ':' in address_value: | ||||
|             address_host, address_port = address_value.split(':') | ||||
|             result = [address_host, 0] | ||||
|             try: | ||||
|                 result[1] = int(address_port) | ||||
|             except ValueError: | ||||
|                 pass | ||||
|         else: | ||||
|             result = [address_value, 0] | ||||
| 
 | ||||
|         if port_header_name: | ||||
|             # We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For | ||||
|             # header to avoid inconsistent results. | ||||
|             port_header_name = port_header_name.lower().encode("utf-8") | ||||
|             if headers.hasHeader(port_header_name): | ||||
|                 port_value = headers.getRawHeaders(port_header_name)[0].decode("utf-8") | ||||
|                 try: | ||||
|                     result[1] = int(port_value) | ||||
|                 except ValueError: | ||||
|                     pass | ||||
| 
 | ||||
|     return result | ||||
|  | @ -9,6 +9,8 @@ from twisted.internet import defer | |||
| 
 | ||||
| from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory, ConnectionDeny | ||||
| 
 | ||||
| from .utils import parse_x_forwarded_for | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -54,6 +56,15 @@ class WebSocketProtocol(WebSocketServerProtocol): | |||
|             else: | ||||
|                 self.client_addr = None | ||||
|                 self.server_addr = None | ||||
| 
 | ||||
|             if self.factory.proxy_forwarded_address_header: | ||||
|                 self.client_addr = parse_x_forwarded_for( | ||||
|                     self.requestHeaders, | ||||
|                     self.main_factory.proxy_forwarded_address_header, | ||||
|                     self.main_factory.proxy_forwarded_port_header, | ||||
|                     self.client_addr | ||||
|                 ) | ||||
| 
 | ||||
|             # Make initial request info dict from request (we only have it here) | ||||
|             self.path = request.path.encode("ascii") | ||||
|             self.request_info = { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user