HTTP protocol tests

This commit is contained in:
Andrew Godwin 2017-11-25 18:23:54 -08:00
parent 0626f39214
commit b72349d2c1
10 changed files with 661 additions and 481 deletions

View File

@ -127,12 +127,7 @@ class WebRequest(http.Request):
# Remove our HTTP reply channel association # Remove our HTTP reply channel association
logger.debug("Upgraded connection %s to WebSocket", self.client_addr) logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
# Resume the producer so we keep getting data, if it's available as a method # Resume the producer so we keep getting data, if it's available as a method
# 17.1 version self.channel._networkProducer.resumeProducing()
if hasattr(self.channel, "_networkProducer"):
self.channel._networkProducer.resumeProducing()
# 16.x version
elif hasattr(self.channel, "resumeProducing"):
self.channel.resumeProducing()
# Boring old HTTP. # Boring old HTTP.
else: else:

View File

@ -106,6 +106,12 @@ class Server(object):
reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications) reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
reactor.run(installSignalHandlers=self.signal_handlers) reactor.run(installSignalHandlers=self.signal_handlers)
def stop(self):
"""
Force-stops the server.
"""
reactor.stop()
### Protocol handling ### Protocol handling
def add_protocol(self, protocol): def add_protocol(self, protocol):
@ -159,16 +165,20 @@ class Server(object):
if application_instance.done(): if application_instance.done():
exception = application_instance.exception() exception = application_instance.exception()
if exception: if exception:
logging.error( if isinstance(exception, KeyboardInterrupt):
"Exception inside application: {}\n{}{}".format( # Protocol is asking the server to exit (likely during test)
exception, self.stop()
"".join(traceback.format_tb( else:
exception.__traceback__, logging.error(
)), "Exception inside application: {}\n{}{}".format(
" {}".format(exception), exception,
"".join(traceback.format_tb(
exception.__traceback__,
)),
" {}".format(exception),
)
) )
) protocol.handle_exception(exception)
protocol.handle_exception(exception)
try: try:
del self.application_instances[protocol] del self.application_instances[protocol]
except KeyError: except KeyError:

86
daphne/test_utils.py Normal file
View File

@ -0,0 +1,86 @@
import msgpack
import os
import tempfile
class TestApplication:
"""
An application that receives one or more messages, sends a response,
and then quits the server. For testing.
"""
setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio")
result_storage = os.path.join(tempfile.gettempdir(), "result.testio")
def __init__(self, scope):
self.scope = scope
self.messages = []
async def __call__(self, send, receive):
# Load setup info
setup = self.load_setup()
try:
for _ in range(setup["receive_messages"]):
self.messages.append(await receive())
for message in setup["response_messages"]:
await send(message)
finally:
self.save_result()
@classmethod
def save_setup(cls, response_messages, receive_messages=1):
"""
Stores setup information.
"""
with open(cls.setup_storage, "wb") as fh:
fh.write(msgpack.packb(
{
"response_messages": response_messages,
"receive_messages": receive_messages,
},
use_bin_type=True,
))
@classmethod
def load_setup(cls):
"""
Returns setup details.
"""
with open(cls.setup_storage, "rb") as fh:
return msgpack.unpackb(fh.read(), encoding="utf-8")
def save_result(self):
"""
Saves details of what happened to the result storage.
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(self.result_storage, "wb") as fh:
fh.write(msgpack.packb(
{
"scope": self.scope,
"messages": self.messages,
},
use_bin_type=True,
))
@classmethod
def load_result(cls):
"""
Returns result details.
"""
with open(cls.result_storage, "rb") as fh:
return msgpack.unpackb(fh.read(), encoding="utf-8")
@classmethod
def clear_storage(cls):
"""
Clears storage files.
"""
try:
os.unlink(cls.setup_storage)
except OSError:
pass
try:
os.unlink(cls.result_storage)
except OSError:
pass

View File

@ -1,197 +0,0 @@
# coding: utf8
"""
Tests for the HTTP request section of the ASGI spec
"""
from __future__ import unicode_literals
import unittest
from six.moves.urllib import parse
from asgiref.inmemory import ChannelLayer
from hypothesis import given, assume, settings, HealthCheck
from twisted.test import proto_helpers
from daphne.http_protocol import HTTPFactory
from daphne.tests import testcases, http_strategies
from daphne.tests.factories import message_for_request, content_length_header
class TestHTTPRequestSpec(testcases.ASGIHTTPTestCase):
"""
Tests which try to pour the HTTP request section of the ASGI spec into code.
The heavy lifting is done by the assert_valid_http_request_message function,
the tests mostly serve to wire up hypothesis so that it exercise it's power to find
edge cases.
"""
def test_minimal_request(self):
"""
Smallest viable example. Mostly verifies that our request building works.
"""
request_method, request_path = "GET", "/"
message = message_for_request(request_method, request_path)
self.assert_valid_http_request_message(message, request_method, request_path)
@given(
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params()
)
def test_get_request(self, request_path, request_params):
"""
Tests a typical HTTP GET request, with a path and query parameters
"""
request_method = "GET"
message = message_for_request(request_method, request_path, request_params)
self.assert_valid_http_request_message(
message, request_method, request_path, request_params=request_params)
@given(
request_path=http_strategies.http_path(),
request_body=http_strategies.http_body()
)
def test_post_request(self, request_path, request_body):
"""
Tests a typical POST request, submitting some data in a body.
"""
request_method = "POST"
headers = [content_length_header(request_body)]
message = message_for_request(
request_method, request_path, headers=headers, body=request_body)
self.assert_valid_http_request_message(
message, request_method, request_path,
request_headers=headers, request_body=request_body)
@given(request_headers=http_strategies.headers())
def test_headers(self, request_headers):
"""
Tests that HTTP header fields are handled as specified
"""
request_method, request_path = "OPTIONS", "/te st-à/"
message = message_for_request(request_method, request_path, headers=request_headers)
self.assert_valid_http_request_message(
message, request_method, request_path, request_headers=request_headers)
@given(request_headers=http_strategies.headers())
def test_duplicate_headers(self, request_headers):
"""
Tests that duplicate header values are preserved
"""
assume(len(request_headers) >= 2)
# Set all header field names to the same value
header_name = request_headers[0][0]
duplicated_headers = [(header_name, header[1]) for header in request_headers]
request_method, request_path = "OPTIONS", "/te st-à/"
message = message_for_request(request_method, request_path, headers=duplicated_headers)
self.assert_valid_http_request_message(
message, request_method, request_path, request_headers=duplicated_headers)
@given(
request_method=http_strategies.http_method(),
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params(),
request_headers=http_strategies.headers(),
request_body=http_strategies.http_body(),
)
# This test is slow enough that on Travis, hypothesis sometimes complains.
@settings(suppress_health_check=[HealthCheck.too_slow])
def test_kitchen_sink(
self, request_method, request_path, request_params, request_headers, request_body):
"""
Throw everything at channels that we dare. The idea is that if a combination
of method/path/headers/body would break the spec, hypothesis will eventually find it.
"""
request_headers.append(content_length_header(request_body))
message = message_for_request(
request_method, request_path, request_params, request_headers, request_body)
self.assert_valid_http_request_message(
message, request_method, request_path, request_params, request_headers, request_body)
def test_headers_are_lowercased_and_stripped(self):
request_method, request_path = "GET", "/"
headers = [("MYCUSTOMHEADER", " foobar ")]
message = message_for_request(request_method, request_path, headers=headers)
self.assert_valid_http_request_message(
message, request_method, request_path, request_headers=headers)
# Note that Daphne returns a list of tuples here, which is fine, because the spec
# asks to treat them interchangeably.
assert message["headers"] == [(b"mycustomheader", b"foobar")]
@given(daphne_path=http_strategies.http_path())
def test_root_path_header(self, daphne_path):
"""
Tests root_path handling.
"""
request_method, request_path = "GET", "/"
# Daphne-Root-Path must be URL encoded when submitting as HTTP header field
headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))]
message = message_for_request(request_method, request_path, headers=headers)
# Daphne-Root-Path is not included in the returned 'headers' section. So we expect
# empty headers.
expected_headers = []
self.assert_valid_http_request_message(
message, request_method, request_path, request_headers=expected_headers)
# And what we're looking for, root_path being set.
assert message["root_path"] == daphne_path
class TestProxyHandling(unittest.TestCase):
"""
Tests that concern interaction of Daphne with proxies.
They live in a separate test case, because they're not part of the spec.
"""
def setUp(self):
self.channel_layer = ChannelLayer()
self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)
def test_x_forwarded_for_ignored(self):
self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" +
b"X-Forwarded-For: 10.1.2.3\r\n" +
b"X-Forwarded-Port: 80\r\n" +
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive(["http.request"])
self.assertEqual(message["client"], ["192.168.1.1", 54321])
def test_x_forwarded_for_parsed(self):
self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" +
b"X-Forwarded-For: 10.1.2.3\r\n" +
b"X-Forwarded-Port: 80\r\n" +
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive(["http.request"])
self.assertEqual(message["client"], ["10.1.2.3", 80])
def test_x_forwarded_for_port_missing(self):
self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" +
b"X-Forwarded-For: 10.1.2.3\r\n" +
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive(["http.request"])
self.assertEqual(message["client"], ["10.1.2.3", 0])

View File

@ -1,238 +0,0 @@
"""
Contains a test case class to allow verifying ASGI messages
"""
from __future__ import unicode_literals
from collections import defaultdict
from urllib import parse
import socket
import unittest
from . import factories
class ASGITestCaseBase(unittest.TestCase):
"""
Base class for our test classes which contains shared method.
"""
def assert_is_ip_address(self, address):
"""
Tests whether a given address string is a valid IPv4 or IPv6 address.
"""
try:
socket.inet_aton(address)
except socket.error:
self.fail("'%s' is not a valid IP address." % address)
def assert_presence_of_message_keys(self, keys, required_keys, optional_keys):
present_keys = set(keys)
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
self.assertEqual(set(), present_keys - required_keys - optional_keys)
def assert_valid_reply_channel(self, reply_channel):
self.assertIsInstance(reply_channel, str)
# The reply channel is decided by the server.
self.assertTrue(reply_channel.startswith("test!"))
def assert_valid_path(self, path, request_path):
self.assertIsInstance(path, str)
self.assertEqual(path, request_path)
# Assert that it's already url decoded
self.assertEqual(path, parse.unquote(path))
def assert_valid_address_and_port(self, host):
address, port = host
self.assertIsInstance(address, str)
self.assert_is_ip_address(address)
self.assertIsInstance(port, int)
class ASGIHTTPTestCase(ASGITestCaseBase):
"""
Test case with helpers for verifying HTTP channel messages
"""
def assert_valid_http_request_message(
self, channel_message, request_method, request_path,
request_params=None, request_headers=None, request_body=None):
"""
Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
"""
self.assertTrue(channel_message)
self.assert_presence_of_message_keys(
channel_message.keys(),
{"reply_channel", "http_version", "method", "path", "query_string", "headers"},
{"scheme", "root_path", "body", "body_channel", "client", "server"})
# == Assertions about required channel_message fields ==
self.assert_valid_reply_channel(channel_message["reply_channel"])
self.assert_valid_path(channel_message["path"], request_path)
http_version = channel_message["http_version"]
self.assertIsInstance(http_version, str)
self.assertIn(http_version, ["1.0", "1.1", "1.2"])
method = channel_message["method"]
self.assertIsInstance(method, str)
self.assertTrue(method.isupper())
self.assertEqual(channel_message["method"], request_method)
query_string = channel_message["query_string"]
# Assert that query_string is a byte string and still url encoded
self.assertIsInstance(query_string, bytes)
self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
# {name: [value1, value2, ...]} and check if they're equal.
transformed_message_headers = defaultdict(list)
for name, value in channel_message["headers"]:
transformed_message_headers[name].append(value)
transformed_request_headers = defaultdict(list)
for name, value in (request_headers or []):
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value)
self.assertEqual(transformed_message_headers, transformed_request_headers)
# == Assertions about optional channel_message fields ==
scheme = channel_message.get("scheme")
if scheme is not None:
self.assertIsInstance(scheme, str)
self.assertTrue(scheme) # May not be empty
root_path = channel_message.get("root_path")
if root_path is not None:
self.assertIsInstance(root_path, str)
body = channel_message.get("body")
# Ensure we test for presence of 'body' if a request body was given
if request_body is not None or body is not None:
self.assertIsInstance(body, str)
self.assertEqual(body, (request_body or "").encode("ascii"))
body_channel = channel_message.get("body_channel")
if body_channel is not None:
self.assertIsInstance(body_channel, str)
self.assertIn("?", body_channel)
client = channel_message.get("client")
if client is not None:
self.assert_valid_address_and_port(channel_message["client"])
server = channel_message.get("server")
if server is not None:
self.assert_valid_address_and_port(channel_message["server"])
def assert_valid_http_response_message(self, message, response):
self.assertTrue(message)
self.assertTrue(response.startswith(b"HTTP"))
status_code_bytes = str(message["status"]).encode("ascii")
self.assertIn(status_code_bytes, response)
if "content" in message:
self.assertIn(message["content"], response)
# Check that headers are in the given order.
# N.b. HTTP spec only enforces that the order of header values is kept, but
# the ASGI spec requires that order of all headers is kept. This code
# checks conformance with the stricter ASGI spec.
if "headers" in message:
for name, value in message["headers"]:
expected_header = factories.header_line(name, value)
# Daphne or Twisted turn our lower cased header names ('foo-bar') into title
# case ('Foo-Bar'). So technically we want to to match that the header name is
# present while ignoring casing, and want to ensure the value is present without
# altered casing. The approach below does this well enough.
self.assertIn(expected_header.lower(), response.lower())
self.assertIn(value.encode("ascii"), response)
class ASGIWebSocketTestCase(ASGITestCaseBase):
"""
Test case with helpers for verifying WebSocket channel messages
"""
def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
self.assertIn(body, response)
self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
def assert_websocket_denied(self, response):
self.assertIn(b"HTTP/1.1 403", response)
def assert_valid_websocket_connect_message(
self, channel_message, request_path="/", request_params=None, request_headers=None):
"""
Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
"""
self.assertTrue(channel_message)
self.assert_presence_of_message_keys(
channel_message.keys(),
{"reply_channel", "path", "headers", "order"},
{"scheme", "query_string", "root_path", "client", "server"})
# == Assertions about required channel_message fields ==
self.assert_valid_reply_channel(channel_message["reply_channel"])
self.assert_valid_path(channel_message["path"], request_path)
order = channel_message["order"]
self.assertIsInstance(order, int)
self.assertEqual(order, 0)
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform the request
# headers and the channel message headers into a set
# {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
# Note that unlike for HTTP, Daphne never gives out individual header values; instead we
# get one string per header field with values separated by comma.
transformed_request_headers = defaultdict(list)
for name, value in (request_headers or []):
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value)
final_request_headers = {
(name, b",".join(value)) for name, value in transformed_request_headers.items()
}
# Websockets carry a lot of additional header fields, so instead of verifying that
# headers look exactly like expected, we just check that the expected header fields
# and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
# and not tested for.
assert final_request_headers.issubset(set(channel_message["headers"]))
# == Assertions about optional channel_message fields ==
scheme = channel_message.get("scheme")
if scheme:
self.assertIsInstance(scheme, six.text_type)
self.assertIn(scheme, ["ws", "wss"])
query_string = channel_message.get("query_string")
if query_string:
# Assert that query_string is a byte string and still url encoded
self.assertIsInstance(query_string, six.binary_type)
self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
root_path = channel_message.get("root_path")
if root_path is not None:
self.assertIsInstance(root_path, six.text_type)
client = channel_message.get("client")
if client is not None:
self.assert_valid_address_and_port(channel_message["client"])
server = channel_message.get("server")
if server is not None:
self.assert_valid_address_and_port(channel_message["server"])

View File

@ -3,12 +3,3 @@ universal=1
[tool:pytest] [tool:pytest]
addopts = tests/ addopts = tests/
[yapf]
based_on_style = pep8
column_limit = 120
join_multiple_lines = false
split_arguments_when_comma_terminated = true
split_before_expression_after_opening_paren = true
split_before_first_argument = true
split_penalty_after_opening_bracket = -10

248
tests/http_base.py Normal file
View File

@ -0,0 +1,248 @@
from urllib import parse
import requests
import socket
import subprocess
import time
import unittest
from daphne.test_utils import TestApplication
class DaphneTestCase(unittest.TestCase):
"""
Base class for Daphne integration test cases.
Boots up a copy of Daphne on a test port and sends it a request, and
retrieves the response. Uses a custom ASGI application and temporary files
to store/retrieve the request/response messages.
"""
def port_in_use(self, port):
"""
Tests if a port is in use on the local machine.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
s.bind(("127.0.0.1", port))
except socket.error as e:
if e.errno in [13, 98]:
return True
else:
raise
else:
return False
finally:
s.close()
def run_daphne(self, method, path, params, data, responses, headers=None, timeout=1):
"""
Runs Daphne with the given request callback (given the base URL)
and response messages.
"""
# Store setup info
TestApplication.clear_storage()
TestApplication.save_setup(
response_messages=responses,
)
# Find a free port
for i in range(11200, 11300):
if not self.port_in_use(i):
port = i
break
else:
raise RuntimeError("Cannot find a free port to test on")
# Launch daphne on that port
process = subprocess.Popen(["daphne", "-p", str(port), "daphne.test_utils:TestApplication"])
try:
for _ in range(100):
time.sleep(0.1)
if self.port_in_use(port):
break
else:
raise RuntimeError("Daphne never came up.")
# Send it the request
url = "http://127.0.0.1:%i%s" % (port, path)
response = requests.request(method, url, params=params, data=data, headers=headers, timeout=timeout)
finally:
# Shut down daphne
process.terminate()
# Load the information
inner_result = TestApplication.load_result()
# Return the inner result and the response
return inner_result, response
def run_daphne_request(self, method, path, params=None, data=None, headers=None):
"""
Convenience method for just testing request handling.
Returns (scope, messages)
"""
if headers is not None:
headers = dict(headers)
inner_result, _ = self.run_daphne(
method=method,
path=path,
params=params,
data=data,
headers=headers,
responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
)
return inner_result["scope"], inner_result["messages"]
def tearDown(self):
"""
Ensures any storage files are cleared.
"""
TestApplication.clear_storage()
def assert_is_ip_address(self, address):
"""
Tests whether a given address string is a valid IPv4 or IPv6 address.
"""
try:
socket.inet_aton(address)
except socket.error:
self.fail("'%s' is not a valid IP address." % address)
def assert_key_sets(self, required_keys, optional_keys, actual_keys):
"""
Asserts that all required_keys are in actual_keys, and that there
are no keys in actual_keys that aren't required or optional.
"""
present_keys = set(actual_keys)
# Make sure all required keys are present
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
self.assertEqual(
set(),
present_keys - required_keys - optional_keys,
)
def assert_valid_path(self, path, request_path):
"""
Checks the path is valid and already url-decoded.
"""
self.assertIsInstance(path, str)
self.assertEqual(path, request_path)
# Assert that it's already url decoded
self.assertEqual(path, parse.unquote(path))
def assert_valid_address_and_port(self, host):
"""
Asserts the value is a valid (host, port) tuple.
"""
address, port = host
self.assertIsInstance(address, str)
self.assert_is_ip_address(address)
self.assertIsInstance(port, int)
# class ASGIHTTPTestCase(ASGITestCaseBase):
# """
# Test case with helpers for verifying HTTP channel messages
# """
# def assert_valid_http_response_message(self, message, response):
# self.assertTrue(message)
# self.assertTrue(response.startswith(b"HTTP"))
# status_code_bytes = str(message["status"]).encode("ascii")
# self.assertIn(status_code_bytes, response)
# if "content" in message:
# self.assertIn(message["content"], response)
# # Check that headers are in the given order.
# # N.b. HTTP spec only enforces that the order of header values is kept, but
# # the ASGI spec requires that order of all headers is kept. This code
# # checks conformance with the stricter ASGI spec.
# if "headers" in message:
# for name, value in message["headers"]:
# expected_header = factories.header_line(name, value)
# # Daphne or Twisted turn our lower cased header names ('foo-bar') into title
# # case ('Foo-Bar'). So technically we want to to match that the header name is
# # present while ignoring casing, and want to ensure the value is present without
# # altered casing. The approach below does this well enough.
# self.assertIn(expected_header.lower(), response.lower())
# self.assertIn(value.encode("ascii"), response)
# class ASGIWebSocketTestCase(ASGITestCaseBase):
# """
# Test case with helpers for verifying WebSocket channel messages
# """
# def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
# self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
# self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
# self.assertIn(body, response)
# self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
# def assert_websocket_denied(self, response):
# self.assertIn(b"HTTP/1.1 403", response)
# def assert_valid_websocket_connect_message(
# self, channel_message, request_path="/", request_params=None, request_headers=None):
# """
# Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
# """
# self.assertTrue(channel_message)
# self.assert_presence_of_message_keys(
# channel_message.keys(),
# {"reply_channel", "path", "headers", "order"},
# {"scheme", "query_string", "root_path", "client", "server"})
# # == Assertions about required channel_message fields ==
# self.assert_valid_reply_channel(channel_message["reply_channel"])
# self.assert_valid_path(channel_message["path"], request_path)
# order = channel_message["order"]
# self.assertIsInstance(order, int)
# self.assertEqual(order, 0)
# # Ordering of header names is not important, but the order of values for a header
# # name is. To assert whether that order is kept, we transform the request
# # headers and the channel message headers into a set
# # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
# # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
# # get one string per header field with values separated by comma.
# transformed_request_headers = defaultdict(list)
# for name, value in (request_headers or []):
# expected_name = name.lower().strip().encode("ascii")
# expected_value = value.strip().encode("ascii")
# transformed_request_headers[expected_name].append(expected_value)
# final_request_headers = {
# (name, b",".join(value)) for name, value in transformed_request_headers.items()
# }
# # Websockets carry a lot of additional header fields, so instead of verifying that
# # headers look exactly like expected, we just check that the expected header fields
# # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
# # and not tested for.
# assert final_request_headers.issubset(set(channel_message["headers"]))
# # == Assertions about optional channel_message fields ==
# scheme = channel_message.get("scheme")
# if scheme:
# self.assertIsInstance(scheme, six.text_type)
# self.assertIn(scheme, ["ws", "wss"])
# query_string = channel_message.get("query_string")
# if query_string:
# # Assert that query_string is a byte string and still url encoded
# self.assertIsInstance(query_string, six.binary_type)
# self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
# root_path = channel_message.get("root_path")
# if root_path is not None:
# self.assertIsInstance(root_path, six.text_type)
# client = channel_message.get("client")
# if client is not None:
# self.assert_valid_address_and_port(channel_message["client"])
# server = channel_message.get("server")
# if server is not None:
# self.assert_valid_address_and_port(channel_message["server"])

View File

@ -27,18 +27,16 @@ def http_path():
Returns a URL path (not encoded). Returns a URL path (not encoded).
""" """
return strategies.lists( return strategies.lists(
_http_path_portion(), min_size=0, max_size=10).map(lambda s: "/" + "/".join(s)) _http_path_portion(),
min_size=0,
max_size=10,
).map(lambda s: "/" + "/".join(s))
def http_body(): def http_body():
""" """
Returns random printable ASCII characters. This may be exceeding what HTTP allows, Returns random binary body data.
but seems to not cause an issue so far.
""" """
return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500)
def binary_payload():
return strategies.binary(min_size=0, average_size=600, max_size=1500) return strategies.binary(min_size=0, average_size=600, max_size=1500)
@ -59,7 +57,11 @@ def valid_bidi(value):
def _domain_label(): def _domain_label():
return strategies.text( return strategies.text(
alphabet=letters, min_size=1, average_size=6, max_size=63).filter(valid_bidi) alphabet=letters,
min_size=1,
average_size=6,
max_size=63,
).filter(valid_bidi)
def international_domain_name(): def international_domain_name():
@ -67,12 +69,19 @@ def international_domain_name():
Returns a byte string of a domain name, IDNA-encoded. Returns a byte string of a domain name, IDNA-encoded.
""" """
return strategies.lists( return strategies.lists(
_domain_label(), min_size=2, average_size=2).map(lambda s: (".".join(s)).encode("idna")) _domain_label(),
min_size=2,
average_size=2,
).map(lambda s: (".".join(s)).encode("idna"))
def _query_param(): def _query_param():
return strategies.text(alphabet=letters, min_size=1, average_size=10, max_size=255).\ return strategies.text(
map(lambda s: s.encode("utf8")) alphabet=letters,
min_size=1,
average_size=10,
max_size=255,
).map(lambda s: s.encode("utf8"))
def query_params(): def query_params():
@ -82,8 +91,10 @@ def query_params():
ensures that the total urlencoded query string is not longer than 1500 characters. ensures that the total urlencoded query string is not longer than 1500 characters.
""" """
return strategies.lists( return strategies.lists(
strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5).\ strategies.tuples(_query_param(), _query_param()),
filter(lambda x: len(parse.urlencode(x)) < 1500) min_size=0,
average_size=5,
).filter(lambda x: len(parse.urlencode(x)) < 1500)
def header_name(): def header_name():
@ -94,7 +105,10 @@ def header_name():
and 20 characters long and 20 characters long
""" """
return strategies.text( return strategies.text(
alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30) alphabet=string.ascii_letters + string.digits + "-",
min_size=1,
max_size=30,
)
def header_value(): def header_value():
@ -106,7 +120,10 @@ def header_value():
""" """
return strategies.text( return strategies.text(
alphabet=string.ascii_letters + string.digits + string.punctuation + " /t", alphabet=string.ascii_letters + string.digits + string.punctuation + " /t",
min_size=1, average_size=40, max_size=8190).filter(lambda s: len(s.encode("utf8")) < 8190) min_size=1,
average_size=40,
max_size=8190,
).filter(lambda s: len(s.encode("utf8")) < 8190)
def headers(): def headers():
@ -118,4 +135,7 @@ def headers():
""" """
return strategies.lists( return strategies.lists(
strategies.tuples(header_name(), header_value()), strategies.tuples(header_name(), header_value()),
min_size=0, average_size=10, max_size=100) min_size=0,
average_size=10,
max_size=100,
)

267
tests/test_http_request.py Normal file
View File

@ -0,0 +1,267 @@
# coding: utf8
import collections
from urllib import parse
from hypothesis import given, assume, settings, HealthCheck
import http_strategies
from http_base import DaphneTestCase
class TestHTTPRequestSpec(DaphneTestCase):
"""
Tests which try to pour the HTTP request section of the ASGI spec into code.
The heavy lifting is done by the assert_valid_http_request_message function,
the tests mostly serve to wire up hypothesis so that it exercise it's power to find
edge cases.
"""
def assert_valid_http_scope(
self,
scope,
method,
path,
params=None,
headers=None,
scheme=None,
):
"""
Checks that the passed scope is a valid ASGI HTTP scope regarding types
and some urlencoding things.
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type", "http_version", "method", "path", "query_string", "headers"},
optional_keys={"scheme", "root_path", "client", "server"},
actual_keys=scope.keys(),
)
# Check that it is the right type
self.assertEqual(scope["type"], "http")
# Method (uppercased unicode string)
self.assertIsInstance(scope["method"], str)
self.assertEqual(scope["method"], method.upper())
# Path
self.assert_valid_path(scope["path"], path)
# HTTP version
self.assertIn(scope["http_version"], ["1.0", "1.1", "1.2"])
# Scheme
self.assertIn(scope["scheme"], ["http", "https"])
if scheme:
self.assertEqual(scheme, scope["scheme"])
# Query string (byte string and still url encoded)
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
# {name: [value1, value2, ...]} and check if they're equal.
transformed_scope_headers = collections.defaultdict(list)
for name, value in scope["headers"]:
transformed_scope_headers[name].append(value)
transformed_request_headers = collections.defaultdict(list)
for name, value in (headers or []):
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value)
for name, value in transformed_request_headers.items():
self.assertIn(name, transformed_scope_headers)
self.assertEqual(value, transformed_scope_headers[name])
# Root path
self.assertIsInstance(scope.get("root_path", ""), str)
# Client and server addresses
client = scope.get("client")
if client is not None:
self.assert_valid_address_and_port(client)
server = scope.get("server")
if server is not None:
self.assert_valid_address_and_port(server)
def assert_valid_http_request_message(
self,
message,
body=None,
):
"""
Asserts that a message is a valid http.request message
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type"},
optional_keys={"body", "more_content"},
actual_keys=message.keys(),
)
# Check that it is the right type
self.assertEqual(message["type"], "http.request")
# If there's a body present, check its type
self.assertIsInstance(message.get("body", b""), bytes)
if body is not None:
self.assertEqual(body, message.get("body", b""))
def test_minimal_request(self):
"""
Smallest viable example. Mostly verifies that our request building works.
"""
scope, messages = self.run_daphne_request("GET", "/")
self.assert_valid_http_scope(scope, "GET", "/")
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params()
)
@settings(max_examples=5, deadline=2000)
def test_get_request(self, request_path, request_params):
"""
Tests a typical HTTP GET request, with a path and query parameters
"""
scope, messages = self.run_daphne_request("GET", request_path, params=request_params)
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
request_path=http_strategies.http_path(),
request_body=http_strategies.http_body()
)
@settings(max_examples=5, deadline=2000)
def test_post_request(self, request_path, request_body):
"""
Tests a typical HTTP POST request, with a path and body.
"""
scope, messages = self.run_daphne_request("POST", request_path, data=request_body)
self.assert_valid_http_scope(scope, "POST", request_path)
self.assert_valid_http_request_message(messages[0], body=request_body)
@given(request_headers=http_strategies.headers())
@settings(max_examples=5, deadline=2000)
def test_headers(self, request_headers):
"""
Tests that HTTP header fields are handled as specified
"""
request_path = "/te st-à/"
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers)
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers)
self.assert_valid_http_request_message(messages[0], body=b"")
# @given(request_headers=http_strategies.headers())
# def test_duplicate_headers(self, request_headers):
# """
# Tests that duplicate header values are preserved
# """
# assume(len(request_headers) >= 2)
# # Set all header field names to the same value
# header_name = request_headers[0][0]
# duplicated_headers = [(header_name, header[1]) for header in request_headers]
# request_method, request_path = "OPTIONS", "/te st-à/"
# message = message_for_request(request_method, request_path, headers=duplicated_headers)
# self.assert_valid_http_request_message(
# message, request_method, request_path, request_headers=duplicated_headers)
# @given(
# request_method=http_strategies.http_method(),
# request_path=http_strategies.http_path(),
# request_params=http_strategies.query_params(),
# request_headers=http_strategies.headers(),
# request_body=http_strategies.http_body(),
# )
# # This test is slow enough that on Travis, hypothesis sometimes complains.
# @settings(suppress_health_check=[HealthCheck.too_slow])
# def test_kitchen_sink(
# self, request_method, request_path, request_params, request_headers, request_body):
# """
# Throw everything at channels that we dare. The idea is that if a combination
# of method/path/headers/body would break the spec, hypothesis will eventually find it.
# """
# request_headers.append(content_length_header(request_body))
# message = message_for_request(
# request_method, request_path, request_params, request_headers, request_body)
# self.assert_valid_http_request_message(
# message, request_method, request_path, request_params, request_headers, request_body)
# def test_headers_are_lowercased_and_stripped(self):
# request_method, request_path = "GET", "/"
# headers = [("MYCUSTOMHEADER", " foobar ")]
# message = message_for_request(request_method, request_path, headers=headers)
# self.assert_valid_http_request_message(
# message, request_method, request_path, request_headers=headers)
# # Note that Daphne returns a list of tuples here, which is fine, because the spec
# # asks to treat them interchangeably.
# assert message["headers"] == [(b"mycustomheader", b"foobar")]
# @given(daphne_path=http_strategies.http_path())
# def test_root_path_header(self, daphne_path):
# """
# Tests root_path handling.
# """
# request_method, request_path = "GET", "/"
# # Daphne-Root-Path must be URL encoded when submitting as HTTP header field
# headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))]
# message = message_for_request(request_method, request_path, headers=headers)
# # Daphne-Root-Path is not included in the returned 'headers' section. So we expect
# # empty headers.
# expected_headers = []
# self.assert_valid_http_request_message(
# message, request_method, request_path, request_headers=expected_headers)
# # And what we're looking for, root_path being set.
# assert message["root_path"] == daphne_path
# class TestProxyHandling(unittest.TestCase):
# """
# Tests that concern interaction of Daphne with proxies.
# They live in a separate test case, because they're not part of the spec.
# """
# def setUp(self):
# self.channel_layer = ChannelLayer()
# self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
# self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
# self.tr = proto_helpers.StringTransport()
# self.proto.makeConnection(self.tr)
# def test_x_forwarded_for_ignored(self):
# self.proto.dataReceived(
# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
# b"Host: somewhere.com\r\n" +
# b"X-Forwarded-For: 10.1.2.3\r\n" +
# b"X-Forwarded-Port: 80\r\n" +
# b"\r\n"
# )
# # Get the resulting message off of the channel layer
# _, message = self.channel_layer.receive(["http.request"])
# self.assertEqual(message["client"], ["192.168.1.1", 54321])
# def test_x_forwarded_for_parsed(self):
# self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
# self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
# self.proto.dataReceived(
# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
# b"Host: somewhere.com\r\n" +
# b"X-Forwarded-For: 10.1.2.3\r\n" +
# b"X-Forwarded-Port: 80\r\n" +
# b"\r\n"
# )
# # Get the resulting message off of the channel layer
# _, message = self.channel_layer.receive(["http.request"])
# self.assertEqual(message["client"], ["10.1.2.3", 80])
# def test_x_forwarded_for_port_missing(self):
# self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
# self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
# self.proto.dataReceived(
# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
# b"Host: somewhere.com\r\n" +
# b"X-Forwarded-For: 10.1.2.3\r\n" +
# b"\r\n"
# )
# # Get the resulting message off of the channel layer
# _, message = self.channel_layer.receive(["http.request"])
# self.assertEqual(message["client"], ["10.1.2.3", 0])

View File

@ -1,11 +1,9 @@
# coding: utf8 # coding: utf8
from __future__ import unicode_literals
from unittest import TestCase
import six
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from unittest import TestCase
from ..utils import parse_x_forwarded_for from daphne.utils import parse_x_forwarded_for
class TestXForwardedForHttpParsing(TestCase): class TestXForwardedForHttpParsing(TestCase):
@ -20,7 +18,7 @@ class TestXForwardedForHttpParsing(TestCase):
}) })
result = parse_x_forwarded_for(headers) result = parse_x_forwarded_for(headers)
self.assertEqual(result, ["10.1.2.3", 1234]) self.assertEqual(result, ["10.1.2.3", 1234])
self.assertIsInstance(result[0], six.text_type) self.assertIsInstance(result[0], str)
def test_address_only(self): def test_address_only(self):
headers = Headers({ headers = Headers({
@ -94,7 +92,7 @@ class TestXForwardedForWsParsing(TestCase):
["1043::a321:0001", 0] ["1043::a321:0001", 0]
) )
def test_multiple_proxys(self): def test_multiple_proxies(self):
headers = { headers = {
b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4", b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4",
} }