mirror of
				https://github.com/django/daphne.git
				synced 2025-11-04 01:27:33 +03:00 
			
		
		
		
	HTTP protocol tests
This commit is contained in:
		
							parent
							
								
									0626f39214
								
							
						
					
					
						commit
						b72349d2c1
					
				| 
						 | 
					@ -127,12 +127,7 @@ class WebRequest(http.Request):
 | 
				
			||||||
                # Remove our HTTP reply channel association
 | 
					                # Remove our HTTP reply channel association
 | 
				
			||||||
                logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
 | 
					                logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
 | 
				
			||||||
                # Resume the producer so we keep getting data, if it's available as a method
 | 
					                # Resume the producer so we keep getting data, if it's available as a method
 | 
				
			||||||
                # 17.1 version
 | 
					 | 
				
			||||||
                if hasattr(self.channel, "_networkProducer"):
 | 
					 | 
				
			||||||
                self.channel._networkProducer.resumeProducing()
 | 
					                self.channel._networkProducer.resumeProducing()
 | 
				
			||||||
                # 16.x version
 | 
					 | 
				
			||||||
                elif hasattr(self.channel, "resumeProducing"):
 | 
					 | 
				
			||||||
                    self.channel.resumeProducing()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Boring old HTTP.
 | 
					            # Boring old HTTP.
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -106,6 +106,12 @@ class Server(object):
 | 
				
			||||||
        reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
 | 
					        reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
 | 
				
			||||||
        reactor.run(installSignalHandlers=self.signal_handlers)
 | 
					        reactor.run(installSignalHandlers=self.signal_handlers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def stop(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Force-stops the server.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        reactor.stop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ### Protocol handling
 | 
					    ### Protocol handling
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_protocol(self, protocol):
 | 
					    def add_protocol(self, protocol):
 | 
				
			||||||
| 
						 | 
					@ -159,6 +165,10 @@ class Server(object):
 | 
				
			||||||
            if application_instance.done():
 | 
					            if application_instance.done():
 | 
				
			||||||
                exception = application_instance.exception()
 | 
					                exception = application_instance.exception()
 | 
				
			||||||
                if exception:
 | 
					                if exception:
 | 
				
			||||||
 | 
					                    if isinstance(exception, KeyboardInterrupt):
 | 
				
			||||||
 | 
					                        # Protocol is asking the server to exit (likely during test)
 | 
				
			||||||
 | 
					                        self.stop()
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
                        logging.error(
 | 
					                        logging.error(
 | 
				
			||||||
                            "Exception inside application: {}\n{}{}".format(
 | 
					                            "Exception inside application: {}\n{}{}".format(
 | 
				
			||||||
                                exception,
 | 
					                                exception,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										86
									
								
								daphne/test_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								daphne/test_utils.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,86 @@
 | 
				
			||||||
 | 
					import msgpack
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import tempfile
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestApplication:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    An application that receives one or more messages, sends a response,
 | 
				
			||||||
 | 
					    and then quits the server. For testing.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio")
 | 
				
			||||||
 | 
					    result_storage = os.path.join(tempfile.gettempdir(), "result.testio")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, scope):
 | 
				
			||||||
 | 
					        self.scope = scope
 | 
				
			||||||
 | 
					        self.messages = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def __call__(self, send, receive):
 | 
				
			||||||
 | 
					        # Load setup info
 | 
				
			||||||
 | 
					        setup = self.load_setup()
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            for _ in range(setup["receive_messages"]):
 | 
				
			||||||
 | 
					                self.messages.append(await receive())
 | 
				
			||||||
 | 
					            for message in setup["response_messages"]:
 | 
				
			||||||
 | 
					                await send(message)
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            self.save_result()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def save_setup(cls, response_messages, receive_messages=1):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Stores setup information.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with open(cls.setup_storage, "wb") as fh:
 | 
				
			||||||
 | 
					            fh.write(msgpack.packb(
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "response_messages": response_messages,
 | 
				
			||||||
 | 
					                    "receive_messages": receive_messages,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					                use_bin_type=True,
 | 
				
			||||||
 | 
					            ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def load_setup(cls):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Returns setup details.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with open(cls.setup_storage, "rb") as fh:
 | 
				
			||||||
 | 
					            return msgpack.unpackb(fh.read(), encoding="utf-8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def save_result(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Saves details of what happened to the result storage.
 | 
				
			||||||
 | 
					        We could use pickle here, but that seems wrong, still, somehow.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with open(self.result_storage, "wb") as fh:
 | 
				
			||||||
 | 
					            fh.write(msgpack.packb(
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "scope": self.scope,
 | 
				
			||||||
 | 
					                    "messages": self.messages,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					                use_bin_type=True,
 | 
				
			||||||
 | 
					            ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def load_result(cls):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Returns result details.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with open(cls.result_storage, "rb") as fh:
 | 
				
			||||||
 | 
					            return msgpack.unpackb(fh.read(), encoding="utf-8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def clear_storage(cls):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Clears storage files.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            os.unlink(cls.setup_storage)
 | 
				
			||||||
 | 
					        except OSError:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            os.unlink(cls.result_storage)
 | 
				
			||||||
 | 
					        except OSError:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
| 
						 | 
					@ -1,197 +0,0 @@
 | 
				
			||||||
# coding: utf8
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
Tests for the HTTP request section of the ASGI spec
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
from __future__ import unicode_literals
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import unittest
 | 
					 | 
				
			||||||
from six.moves.urllib import parse
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from asgiref.inmemory import ChannelLayer
 | 
					 | 
				
			||||||
from hypothesis import given, assume, settings, HealthCheck
 | 
					 | 
				
			||||||
from twisted.test import proto_helpers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from daphne.http_protocol import HTTPFactory
 | 
					 | 
				
			||||||
from daphne.tests import testcases, http_strategies
 | 
					 | 
				
			||||||
from daphne.tests.factories import message_for_request, content_length_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestHTTPRequestSpec(testcases.ASGIHTTPTestCase):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Tests which try to pour the HTTP request section of the ASGI spec into code.
 | 
					 | 
				
			||||||
    The heavy lifting is done by the assert_valid_http_request_message function,
 | 
					 | 
				
			||||||
    the tests mostly serve to wire up hypothesis so that it exercise it's power to find
 | 
					 | 
				
			||||||
    edge cases.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_minimal_request(self):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Smallest viable example. Mostly verifies that our request building works.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        request_method, request_path = "GET", "/"
 | 
					 | 
				
			||||||
        message = message_for_request(request_method, request_path)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assert_valid_http_request_message(message, request_method, request_path)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @given(
 | 
					 | 
				
			||||||
        request_path=http_strategies.http_path(),
 | 
					 | 
				
			||||||
        request_params=http_strategies.query_params()
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    def test_get_request(self, request_path, request_params):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Tests a typical HTTP GET request, with a path and query parameters
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        request_method = "GET"
 | 
					 | 
				
			||||||
        message = message_for_request(request_method, request_path, request_params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assert_valid_http_request_message(
 | 
					 | 
				
			||||||
            message, request_method, request_path, request_params=request_params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @given(
 | 
					 | 
				
			||||||
        request_path=http_strategies.http_path(),
 | 
					 | 
				
			||||||
        request_body=http_strategies.http_body()
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    def test_post_request(self, request_path, request_body):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Tests a typical POST request, submitting some data in a body.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        request_method = "POST"
 | 
					 | 
				
			||||||
        headers = [content_length_header(request_body)]
 | 
					 | 
				
			||||||
        message = message_for_request(
 | 
					 | 
				
			||||||
            request_method, request_path, headers=headers, body=request_body)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assert_valid_http_request_message(
 | 
					 | 
				
			||||||
            message, request_method, request_path,
 | 
					 | 
				
			||||||
            request_headers=headers, request_body=request_body)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @given(request_headers=http_strategies.headers())
 | 
					 | 
				
			||||||
    def test_headers(self, request_headers):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Tests that HTTP header fields are handled as specified
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        request_method, request_path = "OPTIONS", "/te st-à/"
 | 
					 | 
				
			||||||
        message = message_for_request(request_method, request_path, headers=request_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assert_valid_http_request_message(
 | 
					 | 
				
			||||||
            message, request_method, request_path, request_headers=request_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @given(request_headers=http_strategies.headers())
 | 
					 | 
				
			||||||
    def test_duplicate_headers(self, request_headers):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Tests that duplicate header values are preserved
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        assume(len(request_headers) >= 2)
 | 
					 | 
				
			||||||
        # Set all header field names to the same value
 | 
					 | 
				
			||||||
        header_name = request_headers[0][0]
 | 
					 | 
				
			||||||
        duplicated_headers = [(header_name, header[1]) for header in request_headers]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        request_method, request_path = "OPTIONS", "/te st-à/"
 | 
					 | 
				
			||||||
        message = message_for_request(request_method, request_path, headers=duplicated_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assert_valid_http_request_message(
 | 
					 | 
				
			||||||
            message, request_method, request_path, request_headers=duplicated_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @given(
 | 
					 | 
				
			||||||
        request_method=http_strategies.http_method(),
 | 
					 | 
				
			||||||
        request_path=http_strategies.http_path(),
 | 
					 | 
				
			||||||
        request_params=http_strategies.query_params(),
 | 
					 | 
				
			||||||
        request_headers=http_strategies.headers(),
 | 
					 | 
				
			||||||
        request_body=http_strategies.http_body(),
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    # This test is slow enough that on Travis, hypothesis sometimes complains.
 | 
					 | 
				
			||||||
    @settings(suppress_health_check=[HealthCheck.too_slow])
 | 
					 | 
				
			||||||
    def test_kitchen_sink(
 | 
					 | 
				
			||||||
            self, request_method, request_path, request_params, request_headers, request_body):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Throw everything at channels that we dare. The idea is that if a combination
 | 
					 | 
				
			||||||
        of method/path/headers/body would break the spec, hypothesis will eventually find it.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        request_headers.append(content_length_header(request_body))
 | 
					 | 
				
			||||||
        message = message_for_request(
 | 
					 | 
				
			||||||
            request_method, request_path, request_params, request_headers, request_body)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assert_valid_http_request_message(
 | 
					 | 
				
			||||||
            message, request_method, request_path, request_params, request_headers, request_body)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_headers_are_lowercased_and_stripped(self):
 | 
					 | 
				
			||||||
        request_method, request_path = "GET", "/"
 | 
					 | 
				
			||||||
        headers = [("MYCUSTOMHEADER", "   foobar    ")]
 | 
					 | 
				
			||||||
        message = message_for_request(request_method, request_path, headers=headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assert_valid_http_request_message(
 | 
					 | 
				
			||||||
            message, request_method, request_path, request_headers=headers)
 | 
					 | 
				
			||||||
        # Note that Daphne returns a list of tuples here, which is fine, because the spec
 | 
					 | 
				
			||||||
        # asks to treat them interchangeably.
 | 
					 | 
				
			||||||
        assert message["headers"] == [(b"mycustomheader", b"foobar")]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @given(daphne_path=http_strategies.http_path())
 | 
					 | 
				
			||||||
    def test_root_path_header(self, daphne_path):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Tests root_path handling.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        request_method, request_path = "GET", "/"
 | 
					 | 
				
			||||||
        # Daphne-Root-Path must be URL encoded when submitting as HTTP header field
 | 
					 | 
				
			||||||
        headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))]
 | 
					 | 
				
			||||||
        message = message_for_request(request_method, request_path, headers=headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Daphne-Root-Path is not included in the returned 'headers' section. So we expect
 | 
					 | 
				
			||||||
        # empty headers.
 | 
					 | 
				
			||||||
        expected_headers = []
 | 
					 | 
				
			||||||
        self.assert_valid_http_request_message(
 | 
					 | 
				
			||||||
            message, request_method, request_path, request_headers=expected_headers)
 | 
					 | 
				
			||||||
        # And what we're looking for, root_path being set.
 | 
					 | 
				
			||||||
        assert message["root_path"] == daphne_path
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestProxyHandling(unittest.TestCase):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Tests that concern interaction of Daphne with proxies.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    They live in a separate test case, because they're not part of the spec.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        self.channel_layer = ChannelLayer()
 | 
					 | 
				
			||||||
        self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
 | 
					 | 
				
			||||||
        self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
 | 
					 | 
				
			||||||
        self.tr = proto_helpers.StringTransport()
 | 
					 | 
				
			||||||
        self.proto.makeConnection(self.tr)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_x_forwarded_for_ignored(self):
 | 
					 | 
				
			||||||
        self.proto.dataReceived(
 | 
					 | 
				
			||||||
            b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
 | 
					 | 
				
			||||||
            b"Host: somewhere.com\r\n" +
 | 
					 | 
				
			||||||
            b"X-Forwarded-For: 10.1.2.3\r\n" +
 | 
					 | 
				
			||||||
            b"X-Forwarded-Port: 80\r\n" +
 | 
					 | 
				
			||||||
            b"\r\n"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        # Get the resulting message off of the channel layer
 | 
					 | 
				
			||||||
        _, message = self.channel_layer.receive(["http.request"])
 | 
					 | 
				
			||||||
        self.assertEqual(message["client"], ["192.168.1.1", 54321])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_x_forwarded_for_parsed(self):
 | 
					 | 
				
			||||||
        self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
 | 
					 | 
				
			||||||
        self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
 | 
					 | 
				
			||||||
        self.proto.dataReceived(
 | 
					 | 
				
			||||||
            b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
 | 
					 | 
				
			||||||
            b"Host: somewhere.com\r\n" +
 | 
					 | 
				
			||||||
            b"X-Forwarded-For: 10.1.2.3\r\n" +
 | 
					 | 
				
			||||||
            b"X-Forwarded-Port: 80\r\n" +
 | 
					 | 
				
			||||||
            b"\r\n"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        # Get the resulting message off of the channel layer
 | 
					 | 
				
			||||||
        _, message = self.channel_layer.receive(["http.request"])
 | 
					 | 
				
			||||||
        self.assertEqual(message["client"], ["10.1.2.3", 80])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_x_forwarded_for_port_missing(self):
 | 
					 | 
				
			||||||
        self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
 | 
					 | 
				
			||||||
        self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
 | 
					 | 
				
			||||||
        self.proto.dataReceived(
 | 
					 | 
				
			||||||
            b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
 | 
					 | 
				
			||||||
            b"Host: somewhere.com\r\n" +
 | 
					 | 
				
			||||||
            b"X-Forwarded-For: 10.1.2.3\r\n" +
 | 
					 | 
				
			||||||
            b"\r\n"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        # Get the resulting message off of the channel layer
 | 
					 | 
				
			||||||
        _, message = self.channel_layer.receive(["http.request"])
 | 
					 | 
				
			||||||
        self.assertEqual(message["client"], ["10.1.2.3", 0])
 | 
					 | 
				
			||||||
| 
						 | 
					@ -1,238 +0,0 @@
 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
Contains a test case class to allow verifying ASGI messages
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
from __future__ import unicode_literals
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from collections import defaultdict
 | 
					 | 
				
			||||||
from urllib import parse
 | 
					 | 
				
			||||||
import socket
 | 
					 | 
				
			||||||
import unittest
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from . import factories
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ASGITestCaseBase(unittest.TestCase):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Base class for our test classes which contains shared method.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_is_ip_address(self, address):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Tests whether a given address string is a valid IPv4 or IPv6 address.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            socket.inet_aton(address)
 | 
					 | 
				
			||||||
        except socket.error:
 | 
					 | 
				
			||||||
            self.fail("'%s' is not a valid IP address." % address)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_presence_of_message_keys(self, keys, required_keys, optional_keys):
 | 
					 | 
				
			||||||
        present_keys = set(keys)
 | 
					 | 
				
			||||||
        self.assertTrue(required_keys <= present_keys)
 | 
					 | 
				
			||||||
        # Assert that no other keys are present
 | 
					 | 
				
			||||||
        self.assertEqual(set(), present_keys - required_keys - optional_keys)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_valid_reply_channel(self, reply_channel):
 | 
					 | 
				
			||||||
        self.assertIsInstance(reply_channel, str)
 | 
					 | 
				
			||||||
        # The reply channel is decided by the server.
 | 
					 | 
				
			||||||
        self.assertTrue(reply_channel.startswith("test!"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_valid_path(self, path, request_path):
 | 
					 | 
				
			||||||
        self.assertIsInstance(path, str)
 | 
					 | 
				
			||||||
        self.assertEqual(path, request_path)
 | 
					 | 
				
			||||||
        # Assert that it's already url decoded
 | 
					 | 
				
			||||||
        self.assertEqual(path, parse.unquote(path))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_valid_address_and_port(self, host):
 | 
					 | 
				
			||||||
        address, port = host
 | 
					 | 
				
			||||||
        self.assertIsInstance(address, str)
 | 
					 | 
				
			||||||
        self.assert_is_ip_address(address)
 | 
					 | 
				
			||||||
        self.assertIsInstance(port, int)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ASGIHTTPTestCase(ASGITestCaseBase):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Test case with helpers for verifying HTTP channel messages
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_valid_http_request_message(
 | 
					 | 
				
			||||||
            self, channel_message, request_method, request_path,
 | 
					 | 
				
			||||||
            request_params=None, request_headers=None, request_body=None):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertTrue(channel_message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assert_presence_of_message_keys(
 | 
					 | 
				
			||||||
            channel_message.keys(),
 | 
					 | 
				
			||||||
            {"reply_channel", "http_version", "method", "path", "query_string", "headers"},
 | 
					 | 
				
			||||||
            {"scheme", "root_path", "body", "body_channel", "client", "server"})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # == Assertions about required channel_message fields ==
 | 
					 | 
				
			||||||
        self.assert_valid_reply_channel(channel_message["reply_channel"])
 | 
					 | 
				
			||||||
        self.assert_valid_path(channel_message["path"], request_path)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        http_version = channel_message["http_version"]
 | 
					 | 
				
			||||||
        self.assertIsInstance(http_version, str)
 | 
					 | 
				
			||||||
        self.assertIn(http_version, ["1.0", "1.1", "1.2"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        method = channel_message["method"]
 | 
					 | 
				
			||||||
        self.assertIsInstance(method, str)
 | 
					 | 
				
			||||||
        self.assertTrue(method.isupper())
 | 
					 | 
				
			||||||
        self.assertEqual(channel_message["method"], request_method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        query_string = channel_message["query_string"]
 | 
					 | 
				
			||||||
        # Assert that query_string is a byte string and still url encoded
 | 
					 | 
				
			||||||
        self.assertIsInstance(query_string, bytes)
 | 
					 | 
				
			||||||
        self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Ordering of header names is not important, but the order of values for a header
 | 
					 | 
				
			||||||
        # name is. To assert whether that order is kept, we transform both the request
 | 
					 | 
				
			||||||
        # headers and the channel message headers into a dictionary
 | 
					 | 
				
			||||||
        # {name: [value1, value2, ...]} and check if they're equal.
 | 
					 | 
				
			||||||
        transformed_message_headers = defaultdict(list)
 | 
					 | 
				
			||||||
        for name, value in channel_message["headers"]:
 | 
					 | 
				
			||||||
            transformed_message_headers[name].append(value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        transformed_request_headers = defaultdict(list)
 | 
					 | 
				
			||||||
        for name, value in (request_headers or []):
 | 
					 | 
				
			||||||
            expected_name = name.lower().strip().encode("ascii")
 | 
					 | 
				
			||||||
            expected_value = value.strip().encode("ascii")
 | 
					 | 
				
			||||||
            transformed_request_headers[expected_name].append(expected_value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertEqual(transformed_message_headers, transformed_request_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # == Assertions about optional channel_message fields ==
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        scheme = channel_message.get("scheme")
 | 
					 | 
				
			||||||
        if scheme is not None:
 | 
					 | 
				
			||||||
            self.assertIsInstance(scheme, str)
 | 
					 | 
				
			||||||
            self.assertTrue(scheme)  # May not be empty
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root_path = channel_message.get("root_path")
 | 
					 | 
				
			||||||
        if root_path is not None:
 | 
					 | 
				
			||||||
            self.assertIsInstance(root_path, str)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        body = channel_message.get("body")
 | 
					 | 
				
			||||||
        # Ensure we test for presence of 'body' if a request body was given
 | 
					 | 
				
			||||||
        if request_body is not None or body is not None:
 | 
					 | 
				
			||||||
            self.assertIsInstance(body, str)
 | 
					 | 
				
			||||||
            self.assertEqual(body, (request_body or "").encode("ascii"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        body_channel = channel_message.get("body_channel")
 | 
					 | 
				
			||||||
        if body_channel is not None:
 | 
					 | 
				
			||||||
            self.assertIsInstance(body_channel, str)
 | 
					 | 
				
			||||||
            self.assertIn("?", body_channel)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = channel_message.get("client")
 | 
					 | 
				
			||||||
        if client is not None:
 | 
					 | 
				
			||||||
            self.assert_valid_address_and_port(channel_message["client"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        server = channel_message.get("server")
 | 
					 | 
				
			||||||
        if server is not None:
 | 
					 | 
				
			||||||
            self.assert_valid_address_and_port(channel_message["server"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_valid_http_response_message(self, message, response):
 | 
					 | 
				
			||||||
        self.assertTrue(message)
 | 
					 | 
				
			||||||
        self.assertTrue(response.startswith(b"HTTP"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        status_code_bytes = str(message["status"]).encode("ascii")
 | 
					 | 
				
			||||||
        self.assertIn(status_code_bytes, response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if "content" in message:
 | 
					 | 
				
			||||||
            self.assertIn(message["content"], response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Check that headers are in the given order.
 | 
					 | 
				
			||||||
        # N.b. HTTP spec only enforces that the order of header values is kept, but
 | 
					 | 
				
			||||||
        # the ASGI spec requires that order of all headers is kept. This code
 | 
					 | 
				
			||||||
        # checks conformance with the stricter ASGI spec.
 | 
					 | 
				
			||||||
        if "headers" in message:
 | 
					 | 
				
			||||||
            for name, value in message["headers"]:
 | 
					 | 
				
			||||||
                expected_header = factories.header_line(name, value)
 | 
					 | 
				
			||||||
                # Daphne or Twisted turn our lower cased header names ('foo-bar') into title
 | 
					 | 
				
			||||||
                # case ('Foo-Bar'). So technically we want to to match that the header name is
 | 
					 | 
				
			||||||
                # present while ignoring casing, and want to ensure the value is present without
 | 
					 | 
				
			||||||
                # altered casing. The approach below does this well enough.
 | 
					 | 
				
			||||||
                self.assertIn(expected_header.lower(), response.lower())
 | 
					 | 
				
			||||||
                self.assertIn(value.encode("ascii"), response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ASGIWebSocketTestCase(ASGITestCaseBase):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Test case with helpers for verifying WebSocket channel messages
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
 | 
					 | 
				
			||||||
        self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
 | 
					 | 
				
			||||||
        self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
 | 
					 | 
				
			||||||
        self.assertIn(body, response)
 | 
					 | 
				
			||||||
        self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_websocket_denied(self, response):
 | 
					 | 
				
			||||||
        self.assertIn(b"HTTP/1.1 403", response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def assert_valid_websocket_connect_message(
 | 
					 | 
				
			||||||
            self, channel_message, request_path="/", request_params=None, request_headers=None):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertTrue(channel_message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assert_presence_of_message_keys(
 | 
					 | 
				
			||||||
            channel_message.keys(),
 | 
					 | 
				
			||||||
            {"reply_channel", "path", "headers", "order"},
 | 
					 | 
				
			||||||
            {"scheme", "query_string", "root_path", "client", "server"})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # == Assertions about required channel_message fields ==
 | 
					 | 
				
			||||||
        self.assert_valid_reply_channel(channel_message["reply_channel"])
 | 
					 | 
				
			||||||
        self.assert_valid_path(channel_message["path"], request_path)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        order = channel_message["order"]
 | 
					 | 
				
			||||||
        self.assertIsInstance(order, int)
 | 
					 | 
				
			||||||
        self.assertEqual(order, 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Ordering of header names is not important, but the order of values for a header
 | 
					 | 
				
			||||||
        # name is. To assert whether that order is kept, we transform the request
 | 
					 | 
				
			||||||
        # headers and the channel message headers into a set
 | 
					 | 
				
			||||||
        # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
 | 
					 | 
				
			||||||
        # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
 | 
					 | 
				
			||||||
        # get one string per header field with values separated by comma.
 | 
					 | 
				
			||||||
        transformed_request_headers = defaultdict(list)
 | 
					 | 
				
			||||||
        for name, value in (request_headers or []):
 | 
					 | 
				
			||||||
            expected_name = name.lower().strip().encode("ascii")
 | 
					 | 
				
			||||||
            expected_value = value.strip().encode("ascii")
 | 
					 | 
				
			||||||
            transformed_request_headers[expected_name].append(expected_value)
 | 
					 | 
				
			||||||
        final_request_headers = {
 | 
					 | 
				
			||||||
            (name, b",".join(value)) for name, value in transformed_request_headers.items()
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Websockets carry a lot of additional header fields, so instead of verifying that
 | 
					 | 
				
			||||||
        # headers look exactly like expected, we just check that the expected header fields
 | 
					 | 
				
			||||||
        # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
 | 
					 | 
				
			||||||
        # and not tested for.
 | 
					 | 
				
			||||||
        assert final_request_headers.issubset(set(channel_message["headers"]))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # == Assertions about optional channel_message fields ==
 | 
					 | 
				
			||||||
        scheme = channel_message.get("scheme")
 | 
					 | 
				
			||||||
        if scheme:
 | 
					 | 
				
			||||||
            self.assertIsInstance(scheme, six.text_type)
 | 
					 | 
				
			||||||
            self.assertIn(scheme, ["ws", "wss"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        query_string = channel_message.get("query_string")
 | 
					 | 
				
			||||||
        if query_string:
 | 
					 | 
				
			||||||
            # Assert that query_string is a byte string and still url encoded
 | 
					 | 
				
			||||||
            self.assertIsInstance(query_string, six.binary_type)
 | 
					 | 
				
			||||||
            self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root_path = channel_message.get("root_path")
 | 
					 | 
				
			||||||
        if root_path is not None:
 | 
					 | 
				
			||||||
            self.assertIsInstance(root_path, six.text_type)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = channel_message.get("client")
 | 
					 | 
				
			||||||
        if client is not None:
 | 
					 | 
				
			||||||
            self.assert_valid_address_and_port(channel_message["client"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        server = channel_message.get("server")
 | 
					 | 
				
			||||||
        if server is not None:
 | 
					 | 
				
			||||||
            self.assert_valid_address_and_port(channel_message["server"])
 | 
					 | 
				
			||||||
| 
						 | 
					@ -3,12 +3,3 @@ universal=1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[tool:pytest]
 | 
					[tool:pytest]
 | 
				
			||||||
addopts = tests/
 | 
					addopts = tests/
 | 
				
			||||||
 | 
					 | 
				
			||||||
[yapf]
 | 
					 | 
				
			||||||
based_on_style = pep8
 | 
					 | 
				
			||||||
column_limit = 120
 | 
					 | 
				
			||||||
join_multiple_lines = false
 | 
					 | 
				
			||||||
split_arguments_when_comma_terminated = true
 | 
					 | 
				
			||||||
split_before_expression_after_opening_paren = true
 | 
					 | 
				
			||||||
split_before_first_argument = true
 | 
					 | 
				
			||||||
split_penalty_after_opening_bracket = -10
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										248
									
								
								tests/http_base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										248
									
								
								tests/http_base.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,248 @@
 | 
				
			||||||
 | 
					from urllib import parse
 | 
				
			||||||
 | 
					import requests
 | 
				
			||||||
 | 
					import socket
 | 
				
			||||||
 | 
					import subprocess
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from daphne.test_utils import TestApplication
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DaphneTestCase(unittest.TestCase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Base class for Daphne integration test cases.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Boots up a copy of Daphne on a test port and sends it a request, and
 | 
				
			||||||
 | 
					    retrieves the response. Uses a custom ASGI application and temporary files
 | 
				
			||||||
 | 
					    to store/retrieve the request/response messages.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def port_in_use(self, port):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Tests if a port is in use on the local machine.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            s.bind(("127.0.0.1", port))
 | 
				
			||||||
 | 
					        except socket.error as e:
 | 
				
			||||||
 | 
					            if e.errno in [13, 98]:
 | 
				
			||||||
 | 
					                return True
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            s.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def run_daphne(self, method, path, params, data, responses, headers=None, timeout=1):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Runs Daphne with the given request callback (given the base URL)
 | 
				
			||||||
 | 
					        and response messages.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        # Store setup info
 | 
				
			||||||
 | 
					        TestApplication.clear_storage()
 | 
				
			||||||
 | 
					        TestApplication.save_setup(
 | 
				
			||||||
 | 
					            response_messages=responses,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # Find a free port
 | 
				
			||||||
 | 
					        for i in range(11200, 11300):
 | 
				
			||||||
 | 
					            if not self.port_in_use(i):
 | 
				
			||||||
 | 
					                port = i
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise RuntimeError("Cannot find a free port to test on")
 | 
				
			||||||
 | 
					        # Launch daphne on that port
 | 
				
			||||||
 | 
					        process = subprocess.Popen(["daphne", "-p", str(port), "daphne.test_utils:TestApplication"])
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            for _ in range(100):
 | 
				
			||||||
 | 
					                time.sleep(0.1)
 | 
				
			||||||
 | 
					                if self.port_in_use(port):
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                raise RuntimeError("Daphne never came up.")
 | 
				
			||||||
 | 
					            # Send it the request
 | 
				
			||||||
 | 
					            url = "http://127.0.0.1:%i%s" % (port, path)
 | 
				
			||||||
 | 
					            response = requests.request(method, url, params=params, data=data, headers=headers, timeout=timeout)
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            # Shut down daphne
 | 
				
			||||||
 | 
					            process.terminate()
 | 
				
			||||||
 | 
					        # Load the information
 | 
				
			||||||
 | 
					        inner_result = TestApplication.load_result()
 | 
				
			||||||
 | 
					        # Return the inner result and the response
 | 
				
			||||||
 | 
					        return inner_result, response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def run_daphne_request(self, method, path, params=None, data=None, headers=None):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Convenience method for just testing request handling.
 | 
				
			||||||
 | 
					        Returns (scope, messages)
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if headers is not None:
 | 
				
			||||||
 | 
					            headers = dict(headers)
 | 
				
			||||||
 | 
					        inner_result, _ = self.run_daphne(
 | 
				
			||||||
 | 
					            method=method,
 | 
				
			||||||
 | 
					            path=path,
 | 
				
			||||||
 | 
					            params=params,
 | 
				
			||||||
 | 
					            data=data,
 | 
				
			||||||
 | 
					            headers=headers,
 | 
				
			||||||
 | 
					            responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return inner_result["scope"], inner_result["messages"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Ensures any storage files are cleared.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        TestApplication.clear_storage()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_is_ip_address(self, address):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Tests whether a given address string is a valid IPv4 or IPv6 address.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            socket.inet_aton(address)
 | 
				
			||||||
 | 
					        except socket.error:
 | 
				
			||||||
 | 
					            self.fail("'%s' is not a valid IP address." % address)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_key_sets(self, required_keys, optional_keys, actual_keys):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Asserts that all required_keys are in actual_keys, and that there
 | 
				
			||||||
 | 
					        are no keys in actual_keys that aren't required or optional.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        present_keys = set(actual_keys)
 | 
				
			||||||
 | 
					        # Make sure all required keys are present
 | 
				
			||||||
 | 
					        self.assertTrue(required_keys <= present_keys)
 | 
				
			||||||
 | 
					        # Assert that no other keys are present
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            set(),
 | 
				
			||||||
 | 
					            present_keys - required_keys - optional_keys,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_valid_path(self, path, request_path):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Checks the path is valid and already url-decoded.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.assertIsInstance(path, str)
 | 
				
			||||||
 | 
					        self.assertEqual(path, request_path)
 | 
				
			||||||
 | 
					        # Assert that it's already url decoded
 | 
				
			||||||
 | 
					        self.assertEqual(path, parse.unquote(path))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_valid_address_and_port(self, host):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Asserts the value is a valid (host, port) tuple.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        address, port = host
 | 
				
			||||||
 | 
					        self.assertIsInstance(address, str)
 | 
				
			||||||
 | 
					        self.assert_is_ip_address(address)
 | 
				
			||||||
 | 
					        self.assertIsInstance(port, int)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# class ASGIHTTPTestCase(ASGITestCaseBase):
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					#     Test case with helpers for verifying HTTP channel messages
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     def assert_valid_http_response_message(self, message, response):
 | 
				
			||||||
 | 
					#         self.assertTrue(message)
 | 
				
			||||||
 | 
					#         self.assertTrue(response.startswith(b"HTTP"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         status_code_bytes = str(message["status"]).encode("ascii")
 | 
				
			||||||
 | 
					#         self.assertIn(status_code_bytes, response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         if "content" in message:
 | 
				
			||||||
 | 
					#             self.assertIn(message["content"], response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         # Check that headers are in the given order.
 | 
				
			||||||
 | 
					#         # N.b. HTTP spec only enforces that the order of header values is kept, but
 | 
				
			||||||
 | 
					#         # the ASGI spec requires that order of all headers is kept. This code
 | 
				
			||||||
 | 
					#         # checks conformance with the stricter ASGI spec.
 | 
				
			||||||
 | 
					#         if "headers" in message:
 | 
				
			||||||
 | 
					#             for name, value in message["headers"]:
 | 
				
			||||||
 | 
					#                 expected_header = factories.header_line(name, value)
 | 
				
			||||||
 | 
					#                 # Daphne or Twisted turn our lower cased header names ('foo-bar') into title
 | 
				
			||||||
 | 
					#                 # case ('Foo-Bar'). So technically we want to to match that the header name is
 | 
				
			||||||
 | 
					#                 # present while ignoring casing, and want to ensure the value is present without
 | 
				
			||||||
 | 
					#                 # altered casing. The approach below does this well enough.
 | 
				
			||||||
 | 
					#                 self.assertIn(expected_header.lower(), response.lower())
 | 
				
			||||||
 | 
					#                 self.assertIn(value.encode("ascii"), response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# class ASGIWebSocketTestCase(ASGITestCaseBase):
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					#     Test case with helpers for verifying WebSocket channel messages
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
 | 
				
			||||||
 | 
					#         self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
 | 
				
			||||||
 | 
					#         self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
 | 
				
			||||||
 | 
					#         self.assertIn(body, response)
 | 
				
			||||||
 | 
					#         self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     def assert_websocket_denied(self, response):
 | 
				
			||||||
 | 
					#         self.assertIn(b"HTTP/1.1 403", response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     def assert_valid_websocket_connect_message(
 | 
				
			||||||
 | 
					#             self, channel_message, request_path="/", request_params=None, request_headers=None):
 | 
				
			||||||
 | 
					#         """
 | 
				
			||||||
 | 
					#         Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
 | 
				
			||||||
 | 
					#         """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         self.assertTrue(channel_message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         self.assert_presence_of_message_keys(
 | 
				
			||||||
 | 
					#             channel_message.keys(),
 | 
				
			||||||
 | 
					#             {"reply_channel", "path", "headers", "order"},
 | 
				
			||||||
 | 
					#             {"scheme", "query_string", "root_path", "client", "server"})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         # == Assertions about required channel_message fields ==
 | 
				
			||||||
 | 
					#         self.assert_valid_reply_channel(channel_message["reply_channel"])
 | 
				
			||||||
 | 
					#         self.assert_valid_path(channel_message["path"], request_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         order = channel_message["order"]
 | 
				
			||||||
 | 
					#         self.assertIsInstance(order, int)
 | 
				
			||||||
 | 
					#         self.assertEqual(order, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         # Ordering of header names is not important, but the order of values for a header
 | 
				
			||||||
 | 
					#         # name is. To assert whether that order is kept, we transform the request
 | 
				
			||||||
 | 
					#         # headers and the channel message headers into a set
 | 
				
			||||||
 | 
					#         # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
 | 
				
			||||||
 | 
					#         # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
 | 
				
			||||||
 | 
					#         # get one string per header field with values separated by comma.
 | 
				
			||||||
 | 
					#         transformed_request_headers = defaultdict(list)
 | 
				
			||||||
 | 
					#         for name, value in (request_headers or []):
 | 
				
			||||||
 | 
					#             expected_name = name.lower().strip().encode("ascii")
 | 
				
			||||||
 | 
					#             expected_value = value.strip().encode("ascii")
 | 
				
			||||||
 | 
					#             transformed_request_headers[expected_name].append(expected_value)
 | 
				
			||||||
 | 
					#         final_request_headers = {
 | 
				
			||||||
 | 
					#             (name, b",".join(value)) for name, value in transformed_request_headers.items()
 | 
				
			||||||
 | 
					#         }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         # Websockets carry a lot of additional header fields, so instead of verifying that
 | 
				
			||||||
 | 
					#         # headers look exactly like expected, we just check that the expected header fields
 | 
				
			||||||
 | 
					#         # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
 | 
				
			||||||
 | 
					#         # and not tested for.
 | 
				
			||||||
 | 
					#         assert final_request_headers.issubset(set(channel_message["headers"]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         # == Assertions about optional channel_message fields ==
 | 
				
			||||||
 | 
					#         scheme = channel_message.get("scheme")
 | 
				
			||||||
 | 
					#         if scheme:
 | 
				
			||||||
 | 
					#             self.assertIsInstance(scheme, six.text_type)
 | 
				
			||||||
 | 
					#             self.assertIn(scheme, ["ws", "wss"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         query_string = channel_message.get("query_string")
 | 
				
			||||||
 | 
					#         if query_string:
 | 
				
			||||||
 | 
					#             # Assert that query_string is a byte string and still url encoded
 | 
				
			||||||
 | 
					#             self.assertIsInstance(query_string, six.binary_type)
 | 
				
			||||||
 | 
					#             self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         root_path = channel_message.get("root_path")
 | 
				
			||||||
 | 
					#         if root_path is not None:
 | 
				
			||||||
 | 
					#             self.assertIsInstance(root_path, six.text_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         client = channel_message.get("client")
 | 
				
			||||||
 | 
					#         if client is not None:
 | 
				
			||||||
 | 
					#             self.assert_valid_address_and_port(channel_message["client"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         server = channel_message.get("server")
 | 
				
			||||||
 | 
					#         if server is not None:
 | 
				
			||||||
 | 
					#             self.assert_valid_address_and_port(channel_message["server"])
 | 
				
			||||||
| 
						 | 
					@ -27,18 +27,16 @@ def http_path():
 | 
				
			||||||
    Returns a URL path (not encoded).
 | 
					    Returns a URL path (not encoded).
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return strategies.lists(
 | 
					    return strategies.lists(
 | 
				
			||||||
        _http_path_portion(), min_size=0, max_size=10).map(lambda s: "/" + "/".join(s))
 | 
					        _http_path_portion(),
 | 
				
			||||||
 | 
					        min_size=0,
 | 
				
			||||||
 | 
					        max_size=10,
 | 
				
			||||||
 | 
					    ).map(lambda s: "/" + "/".join(s))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def http_body():
 | 
					def http_body():
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Returns random printable ASCII characters. This may be exceeding what HTTP allows,
 | 
					    Returns random binary body data.
 | 
				
			||||||
    but seems to not cause an issue so far.
 | 
					 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def binary_payload():
 | 
					 | 
				
			||||||
    return strategies.binary(min_size=0, average_size=600, max_size=1500)
 | 
					    return strategies.binary(min_size=0, average_size=600, max_size=1500)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -59,7 +57,11 @@ def valid_bidi(value):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _domain_label():
 | 
					def _domain_label():
 | 
				
			||||||
    return strategies.text(
 | 
					    return strategies.text(
 | 
				
			||||||
        alphabet=letters, min_size=1, average_size=6, max_size=63).filter(valid_bidi)
 | 
					        alphabet=letters,
 | 
				
			||||||
 | 
					        min_size=1,
 | 
				
			||||||
 | 
					        average_size=6,
 | 
				
			||||||
 | 
					        max_size=63,
 | 
				
			||||||
 | 
					    ).filter(valid_bidi)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def international_domain_name():
 | 
					def international_domain_name():
 | 
				
			||||||
| 
						 | 
					@ -67,12 +69,19 @@ def international_domain_name():
 | 
				
			||||||
    Returns a byte string of a domain name, IDNA-encoded.
 | 
					    Returns a byte string of a domain name, IDNA-encoded.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return strategies.lists(
 | 
					    return strategies.lists(
 | 
				
			||||||
        _domain_label(), min_size=2, average_size=2).map(lambda s: (".".join(s)).encode("idna"))
 | 
					        _domain_label(),
 | 
				
			||||||
 | 
					        min_size=2,
 | 
				
			||||||
 | 
					        average_size=2,
 | 
				
			||||||
 | 
					    ).map(lambda s: (".".join(s)).encode("idna"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _query_param():
 | 
					def _query_param():
 | 
				
			||||||
    return strategies.text(alphabet=letters, min_size=1, average_size=10, max_size=255).\
 | 
					    return strategies.text(
 | 
				
			||||||
        map(lambda s: s.encode("utf8"))
 | 
					        alphabet=letters,
 | 
				
			||||||
 | 
					        min_size=1,
 | 
				
			||||||
 | 
					        average_size=10,
 | 
				
			||||||
 | 
					        max_size=255,
 | 
				
			||||||
 | 
					    ).map(lambda s: s.encode("utf8"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def query_params():
 | 
					def query_params():
 | 
				
			||||||
| 
						 | 
					@ -82,8 +91,10 @@ def query_params():
 | 
				
			||||||
    ensures that the total urlencoded query string is not longer than 1500 characters.
 | 
					    ensures that the total urlencoded query string is not longer than 1500 characters.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return strategies.lists(
 | 
					    return strategies.lists(
 | 
				
			||||||
        strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5).\
 | 
					        strategies.tuples(_query_param(), _query_param()),
 | 
				
			||||||
        filter(lambda x: len(parse.urlencode(x)) < 1500)
 | 
					        min_size=0,
 | 
				
			||||||
 | 
					        average_size=5,
 | 
				
			||||||
 | 
					    ).filter(lambda x: len(parse.urlencode(x)) < 1500)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def header_name():
 | 
					def header_name():
 | 
				
			||||||
| 
						 | 
					@ -94,7 +105,10 @@ def header_name():
 | 
				
			||||||
    and 20 characters long
 | 
					    and 20 characters long
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return strategies.text(
 | 
					    return strategies.text(
 | 
				
			||||||
        alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30)
 | 
					        alphabet=string.ascii_letters + string.digits + "-",
 | 
				
			||||||
 | 
					        min_size=1,
 | 
				
			||||||
 | 
					        max_size=30,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def header_value():
 | 
					def header_value():
 | 
				
			||||||
| 
						 | 
					@ -106,7 +120,10 @@ def header_value():
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return strategies.text(
 | 
					    return strategies.text(
 | 
				
			||||||
        alphabet=string.ascii_letters + string.digits + string.punctuation + " /t",
 | 
					        alphabet=string.ascii_letters + string.digits + string.punctuation + " /t",
 | 
				
			||||||
        min_size=1, average_size=40, max_size=8190).filter(lambda s: len(s.encode("utf8")) < 8190)
 | 
					        min_size=1,
 | 
				
			||||||
 | 
					        average_size=40,
 | 
				
			||||||
 | 
					        max_size=8190,
 | 
				
			||||||
 | 
					    ).filter(lambda s: len(s.encode("utf8")) < 8190)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def headers():
 | 
					def headers():
 | 
				
			||||||
| 
						 | 
					@ -118,4 +135,7 @@ def headers():
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return strategies.lists(
 | 
					    return strategies.lists(
 | 
				
			||||||
        strategies.tuples(header_name(), header_value()),
 | 
					        strategies.tuples(header_name(), header_value()),
 | 
				
			||||||
        min_size=0, average_size=10, max_size=100)
 | 
					        min_size=0,
 | 
				
			||||||
 | 
					        average_size=10,
 | 
				
			||||||
 | 
					        max_size=100,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
							
								
								
									
										267
									
								
								tests/test_http_request.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										267
									
								
								tests/test_http_request.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,267 @@
 | 
				
			||||||
 | 
					# coding: utf8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import collections
 | 
				
			||||||
 | 
					from urllib import parse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from hypothesis import given, assume, settings, HealthCheck
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import http_strategies
 | 
				
			||||||
 | 
					from http_base import DaphneTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestHTTPRequestSpec(DaphneTestCase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Tests which try to pour the HTTP request section of the ASGI spec into code.
 | 
				
			||||||
 | 
					    The heavy lifting is done by the assert_valid_http_request_message function,
 | 
				
			||||||
 | 
					    the tests mostly serve to wire up hypothesis so that it exercise it's power to find
 | 
				
			||||||
 | 
					    edge cases.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_valid_http_scope(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            scope,
 | 
				
			||||||
 | 
					            method,
 | 
				
			||||||
 | 
					            path,
 | 
				
			||||||
 | 
					            params=None,
 | 
				
			||||||
 | 
					            headers=None,
 | 
				
			||||||
 | 
					            scheme=None,
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Checks that the passed scope is a valid ASGI HTTP scope regarding types
 | 
				
			||||||
 | 
					        and some urlencoding things.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        # Check overall keys
 | 
				
			||||||
 | 
					        self.assert_key_sets(
 | 
				
			||||||
 | 
					            required_keys={"type", "http_version", "method", "path", "query_string", "headers"},
 | 
				
			||||||
 | 
					            optional_keys={"scheme", "root_path", "client", "server"},
 | 
				
			||||||
 | 
					            actual_keys=scope.keys(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # Check that it is the right type
 | 
				
			||||||
 | 
					        self.assertEqual(scope["type"], "http")
 | 
				
			||||||
 | 
					        # Method (uppercased unicode string)
 | 
				
			||||||
 | 
					        self.assertIsInstance(scope["method"], str)
 | 
				
			||||||
 | 
					        self.assertEqual(scope["method"], method.upper())
 | 
				
			||||||
 | 
					        # Path
 | 
				
			||||||
 | 
					        self.assert_valid_path(scope["path"], path)
 | 
				
			||||||
 | 
					        # HTTP version
 | 
				
			||||||
 | 
					        self.assertIn(scope["http_version"], ["1.0", "1.1", "1.2"])
 | 
				
			||||||
 | 
					        # Scheme
 | 
				
			||||||
 | 
					        self.assertIn(scope["scheme"], ["http", "https"])
 | 
				
			||||||
 | 
					        if scheme:
 | 
				
			||||||
 | 
					            self.assertEqual(scheme, scope["scheme"])
 | 
				
			||||||
 | 
					        # Query string (byte string and still url encoded)
 | 
				
			||||||
 | 
					        query_string = scope["query_string"]
 | 
				
			||||||
 | 
					        self.assertIsInstance(query_string, bytes)
 | 
				
			||||||
 | 
					        if params:
 | 
				
			||||||
 | 
					            self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
 | 
				
			||||||
 | 
					        # Ordering of header names is not important, but the order of values for a header
 | 
				
			||||||
 | 
					        # name is. To assert whether that order is kept, we transform both the request
 | 
				
			||||||
 | 
					        # headers and the channel message headers into a dictionary
 | 
				
			||||||
 | 
					        # {name: [value1, value2, ...]} and check if they're equal.
 | 
				
			||||||
 | 
					        transformed_scope_headers = collections.defaultdict(list)
 | 
				
			||||||
 | 
					        for name, value in scope["headers"]:
 | 
				
			||||||
 | 
					            transformed_scope_headers[name].append(value)
 | 
				
			||||||
 | 
					        transformed_request_headers = collections.defaultdict(list)
 | 
				
			||||||
 | 
					        for name, value in (headers or []):
 | 
				
			||||||
 | 
					            expected_name = name.lower().strip().encode("ascii")
 | 
				
			||||||
 | 
					            expected_value = value.strip().encode("ascii")
 | 
				
			||||||
 | 
					            transformed_request_headers[expected_name].append(expected_value)
 | 
				
			||||||
 | 
					        for name, value in transformed_request_headers.items():
 | 
				
			||||||
 | 
					            self.assertIn(name, transformed_scope_headers)
 | 
				
			||||||
 | 
					            self.assertEqual(value, transformed_scope_headers[name])
 | 
				
			||||||
 | 
					        # Root path
 | 
				
			||||||
 | 
					        self.assertIsInstance(scope.get("root_path", ""), str)
 | 
				
			||||||
 | 
					        # Client and server addresses
 | 
				
			||||||
 | 
					        client = scope.get("client")
 | 
				
			||||||
 | 
					        if client is not None:
 | 
				
			||||||
 | 
					            self.assert_valid_address_and_port(client)
 | 
				
			||||||
 | 
					        server = scope.get("server")
 | 
				
			||||||
 | 
					        if server is not None:
 | 
				
			||||||
 | 
					            self.assert_valid_address_and_port(server)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_valid_http_request_message(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            message,
 | 
				
			||||||
 | 
					            body=None,
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Asserts that a message is a valid http.request message
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        # Check overall keys
 | 
				
			||||||
 | 
					        self.assert_key_sets(
 | 
				
			||||||
 | 
					            required_keys={"type"},
 | 
				
			||||||
 | 
					            optional_keys={"body", "more_content"},
 | 
				
			||||||
 | 
					            actual_keys=message.keys(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # Check that it is the right type
 | 
				
			||||||
 | 
					        self.assertEqual(message["type"], "http.request")
 | 
				
			||||||
 | 
					        # If there's a body present, check its type
 | 
				
			||||||
 | 
					        self.assertIsInstance(message.get("body", b""), bytes)
 | 
				
			||||||
 | 
					        if body is not None:
 | 
				
			||||||
 | 
					            self.assertEqual(body, message.get("body", b""))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_minimal_request(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Smallest viable example. Mostly verifies that our request building works.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        scope, messages = self.run_daphne_request("GET", "/")
 | 
				
			||||||
 | 
					        self.assert_valid_http_scope(scope, "GET", "/")
 | 
				
			||||||
 | 
					        self.assert_valid_http_request_message(messages[0], body=b"")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(
 | 
				
			||||||
 | 
					        request_path=http_strategies.http_path(),
 | 
				
			||||||
 | 
					        request_params=http_strategies.query_params()
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    @settings(max_examples=5, deadline=2000)
 | 
				
			||||||
 | 
					    def test_get_request(self, request_path, request_params):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Tests a typical HTTP GET request, with a path and query parameters
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        scope, messages = self.run_daphne_request("GET", request_path, params=request_params)
 | 
				
			||||||
 | 
					        self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
 | 
				
			||||||
 | 
					        self.assert_valid_http_request_message(messages[0], body=b"")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(
 | 
				
			||||||
 | 
					        request_path=http_strategies.http_path(),
 | 
				
			||||||
 | 
					        request_body=http_strategies.http_body()
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    @settings(max_examples=5, deadline=2000)
 | 
				
			||||||
 | 
					    def test_post_request(self, request_path, request_body):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Tests a typical HTTP POST request, with a path and body.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        scope, messages = self.run_daphne_request("POST", request_path, data=request_body)
 | 
				
			||||||
 | 
					        self.assert_valid_http_scope(scope, "POST", request_path)
 | 
				
			||||||
 | 
					        self.assert_valid_http_request_message(messages[0], body=request_body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(request_headers=http_strategies.headers())
 | 
				
			||||||
 | 
					    @settings(max_examples=5, deadline=2000)
 | 
				
			||||||
 | 
					    def test_headers(self, request_headers):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Tests that HTTP header fields are handled as specified
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        request_path = "/te st-à/"
 | 
				
			||||||
 | 
					        scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers)
 | 
				
			||||||
 | 
					        self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers)
 | 
				
			||||||
 | 
					        self.assert_valid_http_request_message(messages[0], body=b"")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # @given(request_headers=http_strategies.headers())
 | 
				
			||||||
 | 
					    # def test_duplicate_headers(self, request_headers):
 | 
				
			||||||
 | 
					    #     """
 | 
				
			||||||
 | 
					    #     Tests that duplicate header values are preserved
 | 
				
			||||||
 | 
					    #     """
 | 
				
			||||||
 | 
					    #     assume(len(request_headers) >= 2)
 | 
				
			||||||
 | 
					    #     # Set all header field names to the same value
 | 
				
			||||||
 | 
					    #     header_name = request_headers[0][0]
 | 
				
			||||||
 | 
					    #     duplicated_headers = [(header_name, header[1]) for header in request_headers]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #     request_method, request_path = "OPTIONS", "/te st-à/"
 | 
				
			||||||
 | 
					    #     message = message_for_request(request_method, request_path, headers=duplicated_headers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #     self.assert_valid_http_request_message(
 | 
				
			||||||
 | 
					    #         message, request_method, request_path, request_headers=duplicated_headers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # @given(
 | 
				
			||||||
 | 
					    #     request_method=http_strategies.http_method(),
 | 
				
			||||||
 | 
					    #     request_path=http_strategies.http_path(),
 | 
				
			||||||
 | 
					    #     request_params=http_strategies.query_params(),
 | 
				
			||||||
 | 
					    #     request_headers=http_strategies.headers(),
 | 
				
			||||||
 | 
					    #     request_body=http_strategies.http_body(),
 | 
				
			||||||
 | 
					    # )
 | 
				
			||||||
 | 
					    # # This test is slow enough that on Travis, hypothesis sometimes complains.
 | 
				
			||||||
 | 
					    # @settings(suppress_health_check=[HealthCheck.too_slow])
 | 
				
			||||||
 | 
					    # def test_kitchen_sink(
 | 
				
			||||||
 | 
					    #         self, request_method, request_path, request_params, request_headers, request_body):
 | 
				
			||||||
 | 
					    #     """
 | 
				
			||||||
 | 
					    #     Throw everything at channels that we dare. The idea is that if a combination
 | 
				
			||||||
 | 
					    #     of method/path/headers/body would break the spec, hypothesis will eventually find it.
 | 
				
			||||||
 | 
					    #     """
 | 
				
			||||||
 | 
					    #     request_headers.append(content_length_header(request_body))
 | 
				
			||||||
 | 
					    #     message = message_for_request(
 | 
				
			||||||
 | 
					    #         request_method, request_path, request_params, request_headers, request_body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #     self.assert_valid_http_request_message(
 | 
				
			||||||
 | 
					    #         message, request_method, request_path, request_params, request_headers, request_body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # def test_headers_are_lowercased_and_stripped(self):
 | 
				
			||||||
 | 
					    #     request_method, request_path = "GET", "/"
 | 
				
			||||||
 | 
					    #     headers = [("MYCUSTOMHEADER", "   foobar    ")]
 | 
				
			||||||
 | 
					    #     message = message_for_request(request_method, request_path, headers=headers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #     self.assert_valid_http_request_message(
 | 
				
			||||||
 | 
					    #         message, request_method, request_path, request_headers=headers)
 | 
				
			||||||
 | 
					    #     # Note that Daphne returns a list of tuples here, which is fine, because the spec
 | 
				
			||||||
 | 
					    #     # asks to treat them interchangeably.
 | 
				
			||||||
 | 
					    #     assert message["headers"] == [(b"mycustomheader", b"foobar")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # @given(daphne_path=http_strategies.http_path())
 | 
				
			||||||
 | 
					    # def test_root_path_header(self, daphne_path):
 | 
				
			||||||
 | 
					    #     """
 | 
				
			||||||
 | 
					    #     Tests root_path handling.
 | 
				
			||||||
 | 
					    #     """
 | 
				
			||||||
 | 
					    #     request_method, request_path = "GET", "/"
 | 
				
			||||||
 | 
					    #     # Daphne-Root-Path must be URL encoded when submitting as HTTP header field
 | 
				
			||||||
 | 
					    #     headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))]
 | 
				
			||||||
 | 
					    #     message = message_for_request(request_method, request_path, headers=headers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #     # Daphne-Root-Path is not included in the returned 'headers' section. So we expect
 | 
				
			||||||
 | 
					    #     # empty headers.
 | 
				
			||||||
 | 
					    #     expected_headers = []
 | 
				
			||||||
 | 
					    #     self.assert_valid_http_request_message(
 | 
				
			||||||
 | 
					    #         message, request_method, request_path, request_headers=expected_headers)
 | 
				
			||||||
 | 
					    #     # And what we're looking for, root_path being set.
 | 
				
			||||||
 | 
					    #     assert message["root_path"] == daphne_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# class TestProxyHandling(unittest.TestCase):
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					#     Tests that concern interaction of Daphne with proxies.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     They live in a separate test case, because they're not part of the spec.
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     def setUp(self):
 | 
				
			||||||
 | 
					#         self.channel_layer = ChannelLayer()
 | 
				
			||||||
 | 
					#         self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
 | 
				
			||||||
 | 
					#         self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
 | 
				
			||||||
 | 
					#         self.tr = proto_helpers.StringTransport()
 | 
				
			||||||
 | 
					#         self.proto.makeConnection(self.tr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     def test_x_forwarded_for_ignored(self):
 | 
				
			||||||
 | 
					#         self.proto.dataReceived(
 | 
				
			||||||
 | 
					#             b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
 | 
				
			||||||
 | 
					#             b"Host: somewhere.com\r\n" +
 | 
				
			||||||
 | 
					#             b"X-Forwarded-For: 10.1.2.3\r\n" +
 | 
				
			||||||
 | 
					#             b"X-Forwarded-Port: 80\r\n" +
 | 
				
			||||||
 | 
					#             b"\r\n"
 | 
				
			||||||
 | 
					#         )
 | 
				
			||||||
 | 
					#         # Get the resulting message off of the channel layer
 | 
				
			||||||
 | 
					#         _, message = self.channel_layer.receive(["http.request"])
 | 
				
			||||||
 | 
					#         self.assertEqual(message["client"], ["192.168.1.1", 54321])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     def test_x_forwarded_for_parsed(self):
 | 
				
			||||||
 | 
					#         self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
 | 
				
			||||||
 | 
					#         self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
 | 
				
			||||||
 | 
					#         self.proto.dataReceived(
 | 
				
			||||||
 | 
					#             b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
 | 
				
			||||||
 | 
					#             b"Host: somewhere.com\r\n" +
 | 
				
			||||||
 | 
					#             b"X-Forwarded-For: 10.1.2.3\r\n" +
 | 
				
			||||||
 | 
					#             b"X-Forwarded-Port: 80\r\n" +
 | 
				
			||||||
 | 
					#             b"\r\n"
 | 
				
			||||||
 | 
					#         )
 | 
				
			||||||
 | 
					#         # Get the resulting message off of the channel layer
 | 
				
			||||||
 | 
					#         _, message = self.channel_layer.receive(["http.request"])
 | 
				
			||||||
 | 
					#         self.assertEqual(message["client"], ["10.1.2.3", 80])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     def test_x_forwarded_for_port_missing(self):
 | 
				
			||||||
 | 
					#         self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
 | 
				
			||||||
 | 
					#         self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
 | 
				
			||||||
 | 
					#         self.proto.dataReceived(
 | 
				
			||||||
 | 
					#             b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
 | 
				
			||||||
 | 
					#             b"Host: somewhere.com\r\n" +
 | 
				
			||||||
 | 
					#             b"X-Forwarded-For: 10.1.2.3\r\n" +
 | 
				
			||||||
 | 
					#             b"\r\n"
 | 
				
			||||||
 | 
					#         )
 | 
				
			||||||
 | 
					#         # Get the resulting message off of the channel layer
 | 
				
			||||||
 | 
					#         _, message = self.channel_layer.receive(["http.request"])
 | 
				
			||||||
 | 
					#         self.assertEqual(message["client"], ["10.1.2.3", 0])
 | 
				
			||||||
| 
						 | 
					@ -1,11 +1,9 @@
 | 
				
			||||||
# coding: utf8
 | 
					# coding: utf8
 | 
				
			||||||
from __future__ import unicode_literals
 | 
					 | 
				
			||||||
from unittest import TestCase
 | 
					 | 
				
			||||||
import six
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from twisted.web.http_headers import Headers
 | 
					from twisted.web.http_headers import Headers
 | 
				
			||||||
 | 
					from unittest import TestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..utils import parse_x_forwarded_for
 | 
					from daphne.utils import parse_x_forwarded_for
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestXForwardedForHttpParsing(TestCase):
 | 
					class TestXForwardedForHttpParsing(TestCase):
 | 
				
			||||||
| 
						 | 
					@ -20,7 +18,7 @@ class TestXForwardedForHttpParsing(TestCase):
 | 
				
			||||||
        })
 | 
					        })
 | 
				
			||||||
        result = parse_x_forwarded_for(headers)
 | 
					        result = parse_x_forwarded_for(headers)
 | 
				
			||||||
        self.assertEqual(result, ["10.1.2.3", 1234])
 | 
					        self.assertEqual(result, ["10.1.2.3", 1234])
 | 
				
			||||||
        self.assertIsInstance(result[0], six.text_type)
 | 
					        self.assertIsInstance(result[0], str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_address_only(self):
 | 
					    def test_address_only(self):
 | 
				
			||||||
        headers = Headers({
 | 
					        headers = Headers({
 | 
				
			||||||
| 
						 | 
					@ -94,7 +92,7 @@ class TestXForwardedForWsParsing(TestCase):
 | 
				
			||||||
            ["1043::a321:0001", 0]
 | 
					            ["1043::a321:0001", 0]
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_multiple_proxys(self):
 | 
					    def test_multiple_proxies(self):
 | 
				
			||||||
        headers = {
 | 
					        headers = {
 | 
				
			||||||
            b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4",
 | 
					            b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4",
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user