From b72349d2c197850c4fd007e0dbaa749fc5d6305a Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Sat, 25 Nov 2017 18:23:54 -0800 Subject: [PATCH] HTTP protocol tests --- daphne/http_protocol.py | 7 +- daphne/server.py | 28 ++- daphne/test_utils.py | 86 +++++++ daphne/tests/test_http_request.py | 197 --------------- daphne/tests/testcases.py | 238 ------------------ setup.cfg | 9 - tests/http_base.py | 248 +++++++++++++++++++ {daphne/tests => tests}/http_strategies.py | 52 ++-- tests/test_http_request.py | 267 +++++++++++++++++++++ {daphne/tests => tests}/test_utils.py | 10 +- 10 files changed, 661 insertions(+), 481 deletions(-) create mode 100644 daphne/test_utils.py delete mode 100644 daphne/tests/test_http_request.py delete mode 100644 daphne/tests/testcases.py create mode 100644 tests/http_base.py rename {daphne/tests => tests}/http_strategies.py (72%) create mode 100644 tests/test_http_request.py rename {daphne/tests => tests}/test_utils.py (93%) diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index c3f4f4b..3170d24 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -127,12 +127,7 @@ class WebRequest(http.Request): # Remove our HTTP reply channel association logger.debug("Upgraded connection %s to WebSocket", self.client_addr) # Resume the producer so we keep getting data, if it's available as a method - # 17.1 version - if hasattr(self.channel, "_networkProducer"): - self.channel._networkProducer.resumeProducing() - # 16.x version - elif hasattr(self.channel, "resumeProducing"): - self.channel.resumeProducing() + self.channel._networkProducer.resumeProducing() # Boring old HTTP. else: diff --git a/daphne/server.py b/daphne/server.py index bf7e051..0892a7c 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -106,6 +106,12 @@ class Server(object): reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications) reactor.run(installSignalHandlers=self.signal_handlers) + def stop(self): + """ + Force-stops the server. + """ + reactor.stop() + ### Protocol handling def add_protocol(self, protocol): @@ -159,16 +165,20 @@ class Server(object): if application_instance.done(): exception = application_instance.exception() if exception: - logging.error( - "Exception inside application: {}\n{}{}".format( - exception, - "".join(traceback.format_tb( - exception.__traceback__, - )), - " {}".format(exception), + if isinstance(exception, KeyboardInterrupt): + # Protocol is asking the server to exit (likely during test) + self.stop() + else: + logging.error( + "Exception inside application: {}\n{}{}".format( + exception, + "".join(traceback.format_tb( + exception.__traceback__, + )), + " {}".format(exception), + ) ) - ) - protocol.handle_exception(exception) + protocol.handle_exception(exception) try: del self.application_instances[protocol] except KeyError: diff --git a/daphne/test_utils.py b/daphne/test_utils.py new file mode 100644 index 0000000..3829d81 --- /dev/null +++ b/daphne/test_utils.py @@ -0,0 +1,86 @@ +import msgpack +import os +import tempfile + + +class TestApplication: + """ + An application that receives one or more messages, sends a response, + and then quits the server. For testing. + """ + + setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio") + result_storage = os.path.join(tempfile.gettempdir(), "result.testio") + + def __init__(self, scope): + self.scope = scope + self.messages = [] + + async def __call__(self, send, receive): + # Load setup info + setup = self.load_setup() + try: + for _ in range(setup["receive_messages"]): + self.messages.append(await receive()) + for message in setup["response_messages"]: + await send(message) + finally: + self.save_result() + + @classmethod + def save_setup(cls, response_messages, receive_messages=1): + """ + Stores setup information. + """ + with open(cls.setup_storage, "wb") as fh: + fh.write(msgpack.packb( + { + "response_messages": response_messages, + "receive_messages": receive_messages, + }, + use_bin_type=True, + )) + + @classmethod + def load_setup(cls): + """ + Returns setup details. + """ + with open(cls.setup_storage, "rb") as fh: + return msgpack.unpackb(fh.read(), encoding="utf-8") + + def save_result(self): + """ + Saves details of what happened to the result storage. + We could use pickle here, but that seems wrong, still, somehow. + """ + with open(self.result_storage, "wb") as fh: + fh.write(msgpack.packb( + { + "scope": self.scope, + "messages": self.messages, + }, + use_bin_type=True, + )) + + @classmethod + def load_result(cls): + """ + Returns result details. + """ + with open(cls.result_storage, "rb") as fh: + return msgpack.unpackb(fh.read(), encoding="utf-8") + + @classmethod + def clear_storage(cls): + """ + Clears storage files. + """ + try: + os.unlink(cls.setup_storage) + except OSError: + pass + try: + os.unlink(cls.result_storage) + except OSError: + pass diff --git a/daphne/tests/test_http_request.py b/daphne/tests/test_http_request.py deleted file mode 100644 index db19b7a..0000000 --- a/daphne/tests/test_http_request.py +++ /dev/null @@ -1,197 +0,0 @@ -# coding: utf8 -""" -Tests for the HTTP request section of the ASGI spec -""" -from __future__ import unicode_literals - -import unittest -from six.moves.urllib import parse - -from asgiref.inmemory import ChannelLayer -from hypothesis import given, assume, settings, HealthCheck -from twisted.test import proto_helpers - -from daphne.http_protocol import HTTPFactory -from daphne.tests import testcases, http_strategies -from daphne.tests.factories import message_for_request, content_length_header - - -class TestHTTPRequestSpec(testcases.ASGIHTTPTestCase): - """ - Tests which try to pour the HTTP request section of the ASGI spec into code. - The heavy lifting is done by the assert_valid_http_request_message function, - the tests mostly serve to wire up hypothesis so that it exercise it's power to find - edge cases. - """ - - def test_minimal_request(self): - """ - Smallest viable example. Mostly verifies that our request building works. - """ - request_method, request_path = "GET", "/" - message = message_for_request(request_method, request_path) - - self.assert_valid_http_request_message(message, request_method, request_path) - - @given( - request_path=http_strategies.http_path(), - request_params=http_strategies.query_params() - ) - def test_get_request(self, request_path, request_params): - """ - Tests a typical HTTP GET request, with a path and query parameters - """ - request_method = "GET" - message = message_for_request(request_method, request_path, request_params) - - self.assert_valid_http_request_message( - message, request_method, request_path, request_params=request_params) - - @given( - request_path=http_strategies.http_path(), - request_body=http_strategies.http_body() - ) - def test_post_request(self, request_path, request_body): - """ - Tests a typical POST request, submitting some data in a body. - """ - request_method = "POST" - headers = [content_length_header(request_body)] - message = message_for_request( - request_method, request_path, headers=headers, body=request_body) - - self.assert_valid_http_request_message( - message, request_method, request_path, - request_headers=headers, request_body=request_body) - - @given(request_headers=http_strategies.headers()) - def test_headers(self, request_headers): - """ - Tests that HTTP header fields are handled as specified - """ - request_method, request_path = "OPTIONS", "/te st-à/" - message = message_for_request(request_method, request_path, headers=request_headers) - - self.assert_valid_http_request_message( - message, request_method, request_path, request_headers=request_headers) - - @given(request_headers=http_strategies.headers()) - def test_duplicate_headers(self, request_headers): - """ - Tests that duplicate header values are preserved - """ - assume(len(request_headers) >= 2) - # Set all header field names to the same value - header_name = request_headers[0][0] - duplicated_headers = [(header_name, header[1]) for header in request_headers] - - request_method, request_path = "OPTIONS", "/te st-à/" - message = message_for_request(request_method, request_path, headers=duplicated_headers) - - self.assert_valid_http_request_message( - message, request_method, request_path, request_headers=duplicated_headers) - - @given( - request_method=http_strategies.http_method(), - request_path=http_strategies.http_path(), - request_params=http_strategies.query_params(), - request_headers=http_strategies.headers(), - request_body=http_strategies.http_body(), - ) - # This test is slow enough that on Travis, hypothesis sometimes complains. - @settings(suppress_health_check=[HealthCheck.too_slow]) - def test_kitchen_sink( - self, request_method, request_path, request_params, request_headers, request_body): - """ - Throw everything at channels that we dare. The idea is that if a combination - of method/path/headers/body would break the spec, hypothesis will eventually find it. - """ - request_headers.append(content_length_header(request_body)) - message = message_for_request( - request_method, request_path, request_params, request_headers, request_body) - - self.assert_valid_http_request_message( - message, request_method, request_path, request_params, request_headers, request_body) - - def test_headers_are_lowercased_and_stripped(self): - request_method, request_path = "GET", "/" - headers = [("MYCUSTOMHEADER", " foobar ")] - message = message_for_request(request_method, request_path, headers=headers) - - self.assert_valid_http_request_message( - message, request_method, request_path, request_headers=headers) - # Note that Daphne returns a list of tuples here, which is fine, because the spec - # asks to treat them interchangeably. - assert message["headers"] == [(b"mycustomheader", b"foobar")] - - @given(daphne_path=http_strategies.http_path()) - def test_root_path_header(self, daphne_path): - """ - Tests root_path handling. - """ - request_method, request_path = "GET", "/" - # Daphne-Root-Path must be URL encoded when submitting as HTTP header field - headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))] - message = message_for_request(request_method, request_path, headers=headers) - - # Daphne-Root-Path is not included in the returned 'headers' section. So we expect - # empty headers. - expected_headers = [] - self.assert_valid_http_request_message( - message, request_method, request_path, request_headers=expected_headers) - # And what we're looking for, root_path being set. - assert message["root_path"] == daphne_path - - -class TestProxyHandling(unittest.TestCase): - """ - Tests that concern interaction of Daphne with proxies. - - They live in a separate test case, because they're not part of the spec. - """ - - def setUp(self): - self.channel_layer = ChannelLayer() - self.factory = HTTPFactory(self.channel_layer, send_channel="test!") - self.proto = self.factory.buildProtocol(("127.0.0.1", 0)) - self.tr = proto_helpers.StringTransport() - self.proto.makeConnection(self.tr) - - def test_x_forwarded_for_ignored(self): - self.proto.dataReceived( - b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + - b"Host: somewhere.com\r\n" + - b"X-Forwarded-For: 10.1.2.3\r\n" + - b"X-Forwarded-Port: 80\r\n" + - b"\r\n" - ) - # Get the resulting message off of the channel layer - _, message = self.channel_layer.receive(["http.request"]) - self.assertEqual(message["client"], ["192.168.1.1", 54321]) - - def test_x_forwarded_for_parsed(self): - self.factory.proxy_forwarded_address_header = "X-Forwarded-For" - self.factory.proxy_forwarded_port_header = "X-Forwarded-Port" - self.proto.dataReceived( - b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + - b"Host: somewhere.com\r\n" + - b"X-Forwarded-For: 10.1.2.3\r\n" + - b"X-Forwarded-Port: 80\r\n" + - b"\r\n" - ) - # Get the resulting message off of the channel layer - _, message = self.channel_layer.receive(["http.request"]) - self.assertEqual(message["client"], ["10.1.2.3", 80]) - - def test_x_forwarded_for_port_missing(self): - self.factory.proxy_forwarded_address_header = "X-Forwarded-For" - self.factory.proxy_forwarded_port_header = "X-Forwarded-Port" - self.proto.dataReceived( - b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + - b"Host: somewhere.com\r\n" + - b"X-Forwarded-For: 10.1.2.3\r\n" + - b"\r\n" - ) - # Get the resulting message off of the channel layer - _, message = self.channel_layer.receive(["http.request"]) - self.assertEqual(message["client"], ["10.1.2.3", 0]) diff --git a/daphne/tests/testcases.py b/daphne/tests/testcases.py deleted file mode 100644 index a31cb60..0000000 --- a/daphne/tests/testcases.py +++ /dev/null @@ -1,238 +0,0 @@ -""" -Contains a test case class to allow verifying ASGI messages -""" -from __future__ import unicode_literals - -from collections import defaultdict -from urllib import parse -import socket -import unittest - -from . import factories - - -class ASGITestCaseBase(unittest.TestCase): - """ - Base class for our test classes which contains shared method. - """ - - def assert_is_ip_address(self, address): - """ - Tests whether a given address string is a valid IPv4 or IPv6 address. - """ - try: - socket.inet_aton(address) - except socket.error: - self.fail("'%s' is not a valid IP address." % address) - - def assert_presence_of_message_keys(self, keys, required_keys, optional_keys): - present_keys = set(keys) - self.assertTrue(required_keys <= present_keys) - # Assert that no other keys are present - self.assertEqual(set(), present_keys - required_keys - optional_keys) - - def assert_valid_reply_channel(self, reply_channel): - self.assertIsInstance(reply_channel, str) - # The reply channel is decided by the server. - self.assertTrue(reply_channel.startswith("test!")) - - def assert_valid_path(self, path, request_path): - self.assertIsInstance(path, str) - self.assertEqual(path, request_path) - # Assert that it's already url decoded - self.assertEqual(path, parse.unquote(path)) - - def assert_valid_address_and_port(self, host): - address, port = host - self.assertIsInstance(address, str) - self.assert_is_ip_address(address) - self.assertIsInstance(port, int) - - -class ASGIHTTPTestCase(ASGITestCaseBase): - """ - Test case with helpers for verifying HTTP channel messages - """ - - def assert_valid_http_request_message( - self, channel_message, request_method, request_path, - request_params=None, request_headers=None, request_body=None): - """ - Asserts that a given channel message conforms to the HTTP request section of the ASGI spec. - """ - - self.assertTrue(channel_message) - - self.assert_presence_of_message_keys( - channel_message.keys(), - {"reply_channel", "http_version", "method", "path", "query_string", "headers"}, - {"scheme", "root_path", "body", "body_channel", "client", "server"}) - - # == Assertions about required channel_message fields == - self.assert_valid_reply_channel(channel_message["reply_channel"]) - self.assert_valid_path(channel_message["path"], request_path) - - http_version = channel_message["http_version"] - self.assertIsInstance(http_version, str) - self.assertIn(http_version, ["1.0", "1.1", "1.2"]) - - method = channel_message["method"] - self.assertIsInstance(method, str) - self.assertTrue(method.isupper()) - self.assertEqual(channel_message["method"], request_method) - - query_string = channel_message["query_string"] - # Assert that query_string is a byte string and still url encoded - self.assertIsInstance(query_string, bytes) - self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii")) - - # Ordering of header names is not important, but the order of values for a header - # name is. To assert whether that order is kept, we transform both the request - # headers and the channel message headers into a dictionary - # {name: [value1, value2, ...]} and check if they're equal. - transformed_message_headers = defaultdict(list) - for name, value in channel_message["headers"]: - transformed_message_headers[name].append(value) - - transformed_request_headers = defaultdict(list) - for name, value in (request_headers or []): - expected_name = name.lower().strip().encode("ascii") - expected_value = value.strip().encode("ascii") - transformed_request_headers[expected_name].append(expected_value) - - self.assertEqual(transformed_message_headers, transformed_request_headers) - - # == Assertions about optional channel_message fields == - - scheme = channel_message.get("scheme") - if scheme is not None: - self.assertIsInstance(scheme, str) - self.assertTrue(scheme) # May not be empty - - root_path = channel_message.get("root_path") - if root_path is not None: - self.assertIsInstance(root_path, str) - - body = channel_message.get("body") - # Ensure we test for presence of 'body' if a request body was given - if request_body is not None or body is not None: - self.assertIsInstance(body, str) - self.assertEqual(body, (request_body or "").encode("ascii")) - - body_channel = channel_message.get("body_channel") - if body_channel is not None: - self.assertIsInstance(body_channel, str) - self.assertIn("?", body_channel) - - client = channel_message.get("client") - if client is not None: - self.assert_valid_address_and_port(channel_message["client"]) - - server = channel_message.get("server") - if server is not None: - self.assert_valid_address_and_port(channel_message["server"]) - - def assert_valid_http_response_message(self, message, response): - self.assertTrue(message) - self.assertTrue(response.startswith(b"HTTP")) - - status_code_bytes = str(message["status"]).encode("ascii") - self.assertIn(status_code_bytes, response) - - if "content" in message: - self.assertIn(message["content"], response) - - # Check that headers are in the given order. - # N.b. HTTP spec only enforces that the order of header values is kept, but - # the ASGI spec requires that order of all headers is kept. This code - # checks conformance with the stricter ASGI spec. - if "headers" in message: - for name, value in message["headers"]: - expected_header = factories.header_line(name, value) - # Daphne or Twisted turn our lower cased header names ('foo-bar') into title - # case ('Foo-Bar'). So technically we want to to match that the header name is - # present while ignoring casing, and want to ensure the value is present without - # altered casing. The approach below does this well enough. - self.assertIn(expected_header.lower(), response.lower()) - self.assertIn(value.encode("ascii"), response) - - -class ASGIWebSocketTestCase(ASGITestCaseBase): - """ - Test case with helpers for verifying WebSocket channel messages - """ - - def assert_websocket_upgrade(self, response, body=b"", expect_close=False): - self.assertIn(b"HTTP/1.1 101 Switching Protocols", response) - self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response) - self.assertIn(body, response) - self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8")) - - def assert_websocket_denied(self, response): - self.assertIn(b"HTTP/1.1 403", response) - - def assert_valid_websocket_connect_message( - self, channel_message, request_path="/", request_params=None, request_headers=None): - """ - Asserts that a given channel message conforms to the HTTP request section of the ASGI spec. - """ - - self.assertTrue(channel_message) - - self.assert_presence_of_message_keys( - channel_message.keys(), - {"reply_channel", "path", "headers", "order"}, - {"scheme", "query_string", "root_path", "client", "server"}) - - # == Assertions about required channel_message fields == - self.assert_valid_reply_channel(channel_message["reply_channel"]) - self.assert_valid_path(channel_message["path"], request_path) - - order = channel_message["order"] - self.assertIsInstance(order, int) - self.assertEqual(order, 0) - - # Ordering of header names is not important, but the order of values for a header - # name is. To assert whether that order is kept, we transform the request - # headers and the channel message headers into a set - # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal. - # Note that unlike for HTTP, Daphne never gives out individual header values; instead we - # get one string per header field with values separated by comma. - transformed_request_headers = defaultdict(list) - for name, value in (request_headers or []): - expected_name = name.lower().strip().encode("ascii") - expected_value = value.strip().encode("ascii") - transformed_request_headers[expected_name].append(expected_value) - final_request_headers = { - (name, b",".join(value)) for name, value in transformed_request_headers.items() - } - - # Websockets carry a lot of additional header fields, so instead of verifying that - # headers look exactly like expected, we just check that the expected header fields - # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed - # and not tested for. - assert final_request_headers.issubset(set(channel_message["headers"])) - - # == Assertions about optional channel_message fields == - scheme = channel_message.get("scheme") - if scheme: - self.assertIsInstance(scheme, six.text_type) - self.assertIn(scheme, ["ws", "wss"]) - - query_string = channel_message.get("query_string") - if query_string: - # Assert that query_string is a byte string and still url encoded - self.assertIsInstance(query_string, six.binary_type) - self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii")) - - root_path = channel_message.get("root_path") - if root_path is not None: - self.assertIsInstance(root_path, six.text_type) - - client = channel_message.get("client") - if client is not None: - self.assert_valid_address_and_port(channel_message["client"]) - - server = channel_message.get("server") - if server is not None: - self.assert_valid_address_and_port(channel_message["server"]) diff --git a/setup.cfg b/setup.cfg index b436afc..d0ede1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,12 +3,3 @@ universal=1 [tool:pytest] addopts = tests/ - -[yapf] -based_on_style = pep8 -column_limit = 120 -join_multiple_lines = false -split_arguments_when_comma_terminated = true -split_before_expression_after_opening_paren = true -split_before_first_argument = true -split_penalty_after_opening_bracket = -10 diff --git a/tests/http_base.py b/tests/http_base.py new file mode 100644 index 0000000..36f2c22 --- /dev/null +++ b/tests/http_base.py @@ -0,0 +1,248 @@ +from urllib import parse +import requests +import socket +import subprocess +import time +import unittest + +from daphne.test_utils import TestApplication + + +class DaphneTestCase(unittest.TestCase): + """ + Base class for Daphne integration test cases. + + Boots up a copy of Daphne on a test port and sends it a request, and + retrieves the response. Uses a custom ASGI application and temporary files + to store/retrieve the request/response messages. + """ + + def port_in_use(self, port): + """ + Tests if a port is in use on the local machine. + """ + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.bind(("127.0.0.1", port)) + except socket.error as e: + if e.errno in [13, 98]: + return True + else: + raise + else: + return False + finally: + s.close() + + def run_daphne(self, method, path, params, data, responses, headers=None, timeout=1): + """ + Runs Daphne with the given request callback (given the base URL) + and response messages. + """ + # Store setup info + TestApplication.clear_storage() + TestApplication.save_setup( + response_messages=responses, + ) + # Find a free port + for i in range(11200, 11300): + if not self.port_in_use(i): + port = i + break + else: + raise RuntimeError("Cannot find a free port to test on") + # Launch daphne on that port + process = subprocess.Popen(["daphne", "-p", str(port), "daphne.test_utils:TestApplication"]) + try: + for _ in range(100): + time.sleep(0.1) + if self.port_in_use(port): + break + else: + raise RuntimeError("Daphne never came up.") + # Send it the request + url = "http://127.0.0.1:%i%s" % (port, path) + response = requests.request(method, url, params=params, data=data, headers=headers, timeout=timeout) + finally: + # Shut down daphne + process.terminate() + # Load the information + inner_result = TestApplication.load_result() + # Return the inner result and the response + return inner_result, response + + def run_daphne_request(self, method, path, params=None, data=None, headers=None): + """ + Convenience method for just testing request handling. + Returns (scope, messages) + """ + if headers is not None: + headers = dict(headers) + inner_result, _ = self.run_daphne( + method=method, + path=path, + params=params, + data=data, + headers=headers, + responses=[{"type": "http.response", "status": 200, "content": b"OK"}], + ) + return inner_result["scope"], inner_result["messages"] + + def tearDown(self): + """ + Ensures any storage files are cleared. + """ + TestApplication.clear_storage() + + def assert_is_ip_address(self, address): + """ + Tests whether a given address string is a valid IPv4 or IPv6 address. + """ + try: + socket.inet_aton(address) + except socket.error: + self.fail("'%s' is not a valid IP address." % address) + + def assert_key_sets(self, required_keys, optional_keys, actual_keys): + """ + Asserts that all required_keys are in actual_keys, and that there + are no keys in actual_keys that aren't required or optional. + """ + present_keys = set(actual_keys) + # Make sure all required keys are present + self.assertTrue(required_keys <= present_keys) + # Assert that no other keys are present + self.assertEqual( + set(), + present_keys - required_keys - optional_keys, + ) + + def assert_valid_path(self, path, request_path): + """ + Checks the path is valid and already url-decoded. + """ + self.assertIsInstance(path, str) + self.assertEqual(path, request_path) + # Assert that it's already url decoded + self.assertEqual(path, parse.unquote(path)) + + def assert_valid_address_and_port(self, host): + """ + Asserts the value is a valid (host, port) tuple. + """ + address, port = host + self.assertIsInstance(address, str) + self.assert_is_ip_address(address) + self.assertIsInstance(port, int) + + +# class ASGIHTTPTestCase(ASGITestCaseBase): +# """ +# Test case with helpers for verifying HTTP channel messages +# """ + + +# def assert_valid_http_response_message(self, message, response): +# self.assertTrue(message) +# self.assertTrue(response.startswith(b"HTTP")) + +# status_code_bytes = str(message["status"]).encode("ascii") +# self.assertIn(status_code_bytes, response) + +# if "content" in message: +# self.assertIn(message["content"], response) + +# # Check that headers are in the given order. +# # N.b. HTTP spec only enforces that the order of header values is kept, but +# # the ASGI spec requires that order of all headers is kept. This code +# # checks conformance with the stricter ASGI spec. +# if "headers" in message: +# for name, value in message["headers"]: +# expected_header = factories.header_line(name, value) +# # Daphne or Twisted turn our lower cased header names ('foo-bar') into title +# # case ('Foo-Bar'). So technically we want to to match that the header name is +# # present while ignoring casing, and want to ensure the value is present without +# # altered casing. The approach below does this well enough. +# self.assertIn(expected_header.lower(), response.lower()) +# self.assertIn(value.encode("ascii"), response) + + +# class ASGIWebSocketTestCase(ASGITestCaseBase): +# """ +# Test case with helpers for verifying WebSocket channel messages +# """ + +# def assert_websocket_upgrade(self, response, body=b"", expect_close=False): +# self.assertIn(b"HTTP/1.1 101 Switching Protocols", response) +# self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response) +# self.assertIn(body, response) +# self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8")) + +# def assert_websocket_denied(self, response): +# self.assertIn(b"HTTP/1.1 403", response) + +# def assert_valid_websocket_connect_message( +# self, channel_message, request_path="/", request_params=None, request_headers=None): +# """ +# Asserts that a given channel message conforms to the HTTP request section of the ASGI spec. +# """ + +# self.assertTrue(channel_message) + +# self.assert_presence_of_message_keys( +# channel_message.keys(), +# {"reply_channel", "path", "headers", "order"}, +# {"scheme", "query_string", "root_path", "client", "server"}) + +# # == Assertions about required channel_message fields == +# self.assert_valid_reply_channel(channel_message["reply_channel"]) +# self.assert_valid_path(channel_message["path"], request_path) + +# order = channel_message["order"] +# self.assertIsInstance(order, int) +# self.assertEqual(order, 0) + +# # Ordering of header names is not important, but the order of values for a header +# # name is. To assert whether that order is kept, we transform the request +# # headers and the channel message headers into a set +# # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal. +# # Note that unlike for HTTP, Daphne never gives out individual header values; instead we +# # get one string per header field with values separated by comma. +# transformed_request_headers = defaultdict(list) +# for name, value in (request_headers or []): +# expected_name = name.lower().strip().encode("ascii") +# expected_value = value.strip().encode("ascii") +# transformed_request_headers[expected_name].append(expected_value) +# final_request_headers = { +# (name, b",".join(value)) for name, value in transformed_request_headers.items() +# } + +# # Websockets carry a lot of additional header fields, so instead of verifying that +# # headers look exactly like expected, we just check that the expected header fields +# # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed +# # and not tested for. +# assert final_request_headers.issubset(set(channel_message["headers"])) + +# # == Assertions about optional channel_message fields == +# scheme = channel_message.get("scheme") +# if scheme: +# self.assertIsInstance(scheme, six.text_type) +# self.assertIn(scheme, ["ws", "wss"]) + +# query_string = channel_message.get("query_string") +# if query_string: +# # Assert that query_string is a byte string and still url encoded +# self.assertIsInstance(query_string, six.binary_type) +# self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii")) + +# root_path = channel_message.get("root_path") +# if root_path is not None: +# self.assertIsInstance(root_path, six.text_type) + +# client = channel_message.get("client") +# if client is not None: +# self.assert_valid_address_and_port(channel_message["client"]) + +# server = channel_message.get("server") +# if server is not None: +# self.assert_valid_address_and_port(channel_message["server"]) diff --git a/daphne/tests/http_strategies.py b/tests/http_strategies.py similarity index 72% rename from daphne/tests/http_strategies.py rename to tests/http_strategies.py index dc26100..e4b3d9c 100644 --- a/daphne/tests/http_strategies.py +++ b/tests/http_strategies.py @@ -27,18 +27,16 @@ def http_path(): Returns a URL path (not encoded). """ return strategies.lists( - _http_path_portion(), min_size=0, max_size=10).map(lambda s: "/" + "/".join(s)) + _http_path_portion(), + min_size=0, + max_size=10, + ).map(lambda s: "/" + "/".join(s)) def http_body(): """ - Returns random printable ASCII characters. This may be exceeding what HTTP allows, - but seems to not cause an issue so far. + Returns random binary body data. """ - return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500) - - -def binary_payload(): return strategies.binary(min_size=0, average_size=600, max_size=1500) @@ -59,7 +57,11 @@ def valid_bidi(value): def _domain_label(): return strategies.text( - alphabet=letters, min_size=1, average_size=6, max_size=63).filter(valid_bidi) + alphabet=letters, + min_size=1, + average_size=6, + max_size=63, + ).filter(valid_bidi) def international_domain_name(): @@ -67,12 +69,19 @@ def international_domain_name(): Returns a byte string of a domain name, IDNA-encoded. """ return strategies.lists( - _domain_label(), min_size=2, average_size=2).map(lambda s: (".".join(s)).encode("idna")) + _domain_label(), + min_size=2, + average_size=2, + ).map(lambda s: (".".join(s)).encode("idna")) def _query_param(): - return strategies.text(alphabet=letters, min_size=1, average_size=10, max_size=255).\ - map(lambda s: s.encode("utf8")) + return strategies.text( + alphabet=letters, + min_size=1, + average_size=10, + max_size=255, + ).map(lambda s: s.encode("utf8")) def query_params(): @@ -82,8 +91,10 @@ def query_params(): ensures that the total urlencoded query string is not longer than 1500 characters. """ return strategies.lists( - strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5).\ - filter(lambda x: len(parse.urlencode(x)) < 1500) + strategies.tuples(_query_param(), _query_param()), + min_size=0, + average_size=5, + ).filter(lambda x: len(parse.urlencode(x)) < 1500) def header_name(): @@ -94,7 +105,10 @@ def header_name(): and 20 characters long """ return strategies.text( - alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30) + alphabet=string.ascii_letters + string.digits + "-", + min_size=1, + max_size=30, + ) def header_value(): @@ -106,7 +120,10 @@ def header_value(): """ return strategies.text( alphabet=string.ascii_letters + string.digits + string.punctuation + " /t", - min_size=1, average_size=40, max_size=8190).filter(lambda s: len(s.encode("utf8")) < 8190) + min_size=1, + average_size=40, + max_size=8190, + ).filter(lambda s: len(s.encode("utf8")) < 8190) def headers(): @@ -118,4 +135,7 @@ def headers(): """ return strategies.lists( strategies.tuples(header_name(), header_value()), - min_size=0, average_size=10, max_size=100) + min_size=0, + average_size=10, + max_size=100, + ) diff --git a/tests/test_http_request.py b/tests/test_http_request.py new file mode 100644 index 0000000..4ea6b65 --- /dev/null +++ b/tests/test_http_request.py @@ -0,0 +1,267 @@ +# coding: utf8 + +import collections +from urllib import parse + +from hypothesis import given, assume, settings, HealthCheck + +import http_strategies +from http_base import DaphneTestCase + + +class TestHTTPRequestSpec(DaphneTestCase): + """ + Tests which try to pour the HTTP request section of the ASGI spec into code. + The heavy lifting is done by the assert_valid_http_request_message function, + the tests mostly serve to wire up hypothesis so that it exercise it's power to find + edge cases. + """ + + def assert_valid_http_scope( + self, + scope, + method, + path, + params=None, + headers=None, + scheme=None, + ): + """ + Checks that the passed scope is a valid ASGI HTTP scope regarding types + and some urlencoding things. + """ + # Check overall keys + self.assert_key_sets( + required_keys={"type", "http_version", "method", "path", "query_string", "headers"}, + optional_keys={"scheme", "root_path", "client", "server"}, + actual_keys=scope.keys(), + ) + # Check that it is the right type + self.assertEqual(scope["type"], "http") + # Method (uppercased unicode string) + self.assertIsInstance(scope["method"], str) + self.assertEqual(scope["method"], method.upper()) + # Path + self.assert_valid_path(scope["path"], path) + # HTTP version + self.assertIn(scope["http_version"], ["1.0", "1.1", "1.2"]) + # Scheme + self.assertIn(scope["scheme"], ["http", "https"]) + if scheme: + self.assertEqual(scheme, scope["scheme"]) + # Query string (byte string and still url encoded) + query_string = scope["query_string"] + self.assertIsInstance(query_string, bytes) + if params: + self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii")) + # Ordering of header names is not important, but the order of values for a header + # name is. To assert whether that order is kept, we transform both the request + # headers and the channel message headers into a dictionary + # {name: [value1, value2, ...]} and check if they're equal. + transformed_scope_headers = collections.defaultdict(list) + for name, value in scope["headers"]: + transformed_scope_headers[name].append(value) + transformed_request_headers = collections.defaultdict(list) + for name, value in (headers or []): + expected_name = name.lower().strip().encode("ascii") + expected_value = value.strip().encode("ascii") + transformed_request_headers[expected_name].append(expected_value) + for name, value in transformed_request_headers.items(): + self.assertIn(name, transformed_scope_headers) + self.assertEqual(value, transformed_scope_headers[name]) + # Root path + self.assertIsInstance(scope.get("root_path", ""), str) + # Client and server addresses + client = scope.get("client") + if client is not None: + self.assert_valid_address_and_port(client) + server = scope.get("server") + if server is not None: + self.assert_valid_address_and_port(server) + + def assert_valid_http_request_message( + self, + message, + body=None, + ): + """ + Asserts that a message is a valid http.request message + """ + # Check overall keys + self.assert_key_sets( + required_keys={"type"}, + optional_keys={"body", "more_content"}, + actual_keys=message.keys(), + ) + # Check that it is the right type + self.assertEqual(message["type"], "http.request") + # If there's a body present, check its type + self.assertIsInstance(message.get("body", b""), bytes) + if body is not None: + self.assertEqual(body, message.get("body", b"")) + + def test_minimal_request(self): + """ + Smallest viable example. Mostly verifies that our request building works. + """ + scope, messages = self.run_daphne_request("GET", "/") + self.assert_valid_http_scope(scope, "GET", "/") + self.assert_valid_http_request_message(messages[0], body=b"") + + @given( + request_path=http_strategies.http_path(), + request_params=http_strategies.query_params() + ) + @settings(max_examples=5, deadline=2000) + def test_get_request(self, request_path, request_params): + """ + Tests a typical HTTP GET request, with a path and query parameters + """ + scope, messages = self.run_daphne_request("GET", request_path, params=request_params) + self.assert_valid_http_scope(scope, "GET", request_path, params=request_params) + self.assert_valid_http_request_message(messages[0], body=b"") + + @given( + request_path=http_strategies.http_path(), + request_body=http_strategies.http_body() + ) + @settings(max_examples=5, deadline=2000) + def test_post_request(self, request_path, request_body): + """ + Tests a typical HTTP POST request, with a path and body. + """ + scope, messages = self.run_daphne_request("POST", request_path, data=request_body) + self.assert_valid_http_scope(scope, "POST", request_path) + self.assert_valid_http_request_message(messages[0], body=request_body) + + @given(request_headers=http_strategies.headers()) + @settings(max_examples=5, deadline=2000) + def test_headers(self, request_headers): + """ + Tests that HTTP header fields are handled as specified + """ + request_path = "/te st-à/" + scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers) + self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers) + self.assert_valid_http_request_message(messages[0], body=b"") + + # @given(request_headers=http_strategies.headers()) + # def test_duplicate_headers(self, request_headers): + # """ + # Tests that duplicate header values are preserved + # """ + # assume(len(request_headers) >= 2) + # # Set all header field names to the same value + # header_name = request_headers[0][0] + # duplicated_headers = [(header_name, header[1]) for header in request_headers] + + # request_method, request_path = "OPTIONS", "/te st-à/" + # message = message_for_request(request_method, request_path, headers=duplicated_headers) + + # self.assert_valid_http_request_message( + # message, request_method, request_path, request_headers=duplicated_headers) + + # @given( + # request_method=http_strategies.http_method(), + # request_path=http_strategies.http_path(), + # request_params=http_strategies.query_params(), + # request_headers=http_strategies.headers(), + # request_body=http_strategies.http_body(), + # ) + # # This test is slow enough that on Travis, hypothesis sometimes complains. + # @settings(suppress_health_check=[HealthCheck.too_slow]) + # def test_kitchen_sink( + # self, request_method, request_path, request_params, request_headers, request_body): + # """ + # Throw everything at channels that we dare. The idea is that if a combination + # of method/path/headers/body would break the spec, hypothesis will eventually find it. + # """ + # request_headers.append(content_length_header(request_body)) + # message = message_for_request( + # request_method, request_path, request_params, request_headers, request_body) + + # self.assert_valid_http_request_message( + # message, request_method, request_path, request_params, request_headers, request_body) + + # def test_headers_are_lowercased_and_stripped(self): + # request_method, request_path = "GET", "/" + # headers = [("MYCUSTOMHEADER", " foobar ")] + # message = message_for_request(request_method, request_path, headers=headers) + + # self.assert_valid_http_request_message( + # message, request_method, request_path, request_headers=headers) + # # Note that Daphne returns a list of tuples here, which is fine, because the spec + # # asks to treat them interchangeably. + # assert message["headers"] == [(b"mycustomheader", b"foobar")] + + # @given(daphne_path=http_strategies.http_path()) + # def test_root_path_header(self, daphne_path): + # """ + # Tests root_path handling. + # """ + # request_method, request_path = "GET", "/" + # # Daphne-Root-Path must be URL encoded when submitting as HTTP header field + # headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))] + # message = message_for_request(request_method, request_path, headers=headers) + + # # Daphne-Root-Path is not included in the returned 'headers' section. So we expect + # # empty headers. + # expected_headers = [] + # self.assert_valid_http_request_message( + # message, request_method, request_path, request_headers=expected_headers) + # # And what we're looking for, root_path being set. + # assert message["root_path"] == daphne_path + + +# class TestProxyHandling(unittest.TestCase): +# """ +# Tests that concern interaction of Daphne with proxies. + +# They live in a separate test case, because they're not part of the spec. +# """ + +# def setUp(self): +# self.channel_layer = ChannelLayer() +# self.factory = HTTPFactory(self.channel_layer, send_channel="test!") +# self.proto = self.factory.buildProtocol(("127.0.0.1", 0)) +# self.tr = proto_helpers.StringTransport() +# self.proto.makeConnection(self.tr) + +# def test_x_forwarded_for_ignored(self): +# self.proto.dataReceived( +# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + +# b"Host: somewhere.com\r\n" + +# b"X-Forwarded-For: 10.1.2.3\r\n" + +# b"X-Forwarded-Port: 80\r\n" + +# b"\r\n" +# ) +# # Get the resulting message off of the channel layer +# _, message = self.channel_layer.receive(["http.request"]) +# self.assertEqual(message["client"], ["192.168.1.1", 54321]) + +# def test_x_forwarded_for_parsed(self): +# self.factory.proxy_forwarded_address_header = "X-Forwarded-For" +# self.factory.proxy_forwarded_port_header = "X-Forwarded-Port" +# self.proto.dataReceived( +# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + +# b"Host: somewhere.com\r\n" + +# b"X-Forwarded-For: 10.1.2.3\r\n" + +# b"X-Forwarded-Port: 80\r\n" + +# b"\r\n" +# ) +# # Get the resulting message off of the channel layer +# _, message = self.channel_layer.receive(["http.request"]) +# self.assertEqual(message["client"], ["10.1.2.3", 80]) + +# def test_x_forwarded_for_port_missing(self): +# self.factory.proxy_forwarded_address_header = "X-Forwarded-For" +# self.factory.proxy_forwarded_port_header = "X-Forwarded-Port" +# self.proto.dataReceived( +# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + +# b"Host: somewhere.com\r\n" + +# b"X-Forwarded-For: 10.1.2.3\r\n" + +# b"\r\n" +# ) +# # Get the resulting message off of the channel layer +# _, message = self.channel_layer.receive(["http.request"]) +# self.assertEqual(message["client"], ["10.1.2.3", 0]) diff --git a/daphne/tests/test_utils.py b/tests/test_utils.py similarity index 93% rename from daphne/tests/test_utils.py rename to tests/test_utils.py index fcac783..5f0b364 100644 --- a/daphne/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,9 @@ # coding: utf8 -from __future__ import unicode_literals -from unittest import TestCase -import six from twisted.web.http_headers import Headers +from unittest import TestCase -from ..utils import parse_x_forwarded_for +from daphne.utils import parse_x_forwarded_for class TestXForwardedForHttpParsing(TestCase): @@ -20,7 +18,7 @@ class TestXForwardedForHttpParsing(TestCase): }) result = parse_x_forwarded_for(headers) self.assertEqual(result, ["10.1.2.3", 1234]) - self.assertIsInstance(result[0], six.text_type) + self.assertIsInstance(result[0], str) def test_address_only(self): headers = Headers({ @@ -94,7 +92,7 @@ class TestXForwardedForWsParsing(TestCase): ["1043::a321:0001", 0] ) - def test_multiple_proxys(self): + def test_multiple_proxies(self): headers = { b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4", }