HTTP protocol tests

This commit is contained in:
Andrew Godwin 2017-11-25 18:23:54 -08:00
parent 0626f39214
commit b72349d2c1
10 changed files with 661 additions and 481 deletions

View File

@ -127,12 +127,7 @@ class WebRequest(http.Request):
# Remove our HTTP reply channel association
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
# Resume the producer so we keep getting data, if it's available as a method
# 17.1 version
if hasattr(self.channel, "_networkProducer"):
self.channel._networkProducer.resumeProducing()
# 16.x version
elif hasattr(self.channel, "resumeProducing"):
self.channel.resumeProducing()
self.channel._networkProducer.resumeProducing()
# Boring old HTTP.
else:

View File

@ -106,6 +106,12 @@ class Server(object):
reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
reactor.run(installSignalHandlers=self.signal_handlers)
def stop(self):
"""
Force-stops the server.
"""
reactor.stop()
### Protocol handling
def add_protocol(self, protocol):
@ -159,16 +165,20 @@ class Server(object):
if application_instance.done():
exception = application_instance.exception()
if exception:
logging.error(
"Exception inside application: {}\n{}{}".format(
exception,
"".join(traceback.format_tb(
exception.__traceback__,
)),
" {}".format(exception),
if isinstance(exception, KeyboardInterrupt):
# Protocol is asking the server to exit (likely during test)
self.stop()
else:
logging.error(
"Exception inside application: {}\n{}{}".format(
exception,
"".join(traceback.format_tb(
exception.__traceback__,
)),
" {}".format(exception),
)
)
)
protocol.handle_exception(exception)
protocol.handle_exception(exception)
try:
del self.application_instances[protocol]
except KeyError:

86
daphne/test_utils.py Normal file
View File

@ -0,0 +1,86 @@
import msgpack
import os
import tempfile
class TestApplication:
"""
An application that receives one or more messages, sends a response,
and then quits the server. For testing.
"""
setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio")
result_storage = os.path.join(tempfile.gettempdir(), "result.testio")
def __init__(self, scope):
self.scope = scope
self.messages = []
async def __call__(self, send, receive):
# Load setup info
setup = self.load_setup()
try:
for _ in range(setup["receive_messages"]):
self.messages.append(await receive())
for message in setup["response_messages"]:
await send(message)
finally:
self.save_result()
@classmethod
def save_setup(cls, response_messages, receive_messages=1):
"""
Stores setup information.
"""
with open(cls.setup_storage, "wb") as fh:
fh.write(msgpack.packb(
{
"response_messages": response_messages,
"receive_messages": receive_messages,
},
use_bin_type=True,
))
@classmethod
def load_setup(cls):
"""
Returns setup details.
"""
with open(cls.setup_storage, "rb") as fh:
return msgpack.unpackb(fh.read(), encoding="utf-8")
def save_result(self):
"""
Saves details of what happened to the result storage.
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(self.result_storage, "wb") as fh:
fh.write(msgpack.packb(
{
"scope": self.scope,
"messages": self.messages,
},
use_bin_type=True,
))
@classmethod
def load_result(cls):
"""
Returns result details.
"""
with open(cls.result_storage, "rb") as fh:
return msgpack.unpackb(fh.read(), encoding="utf-8")
@classmethod
def clear_storage(cls):
"""
Clears storage files.
"""
try:
os.unlink(cls.setup_storage)
except OSError:
pass
try:
os.unlink(cls.result_storage)
except OSError:
pass

View File

@ -1,197 +0,0 @@
# coding: utf8
"""
Tests for the HTTP request section of the ASGI spec
"""
from __future__ import unicode_literals
import unittest
from six.moves.urllib import parse
from asgiref.inmemory import ChannelLayer
from hypothesis import given, assume, settings, HealthCheck
from twisted.test import proto_helpers
from daphne.http_protocol import HTTPFactory
from daphne.tests import testcases, http_strategies
from daphne.tests.factories import message_for_request, content_length_header
class TestHTTPRequestSpec(testcases.ASGIHTTPTestCase):
"""
Tests which try to pour the HTTP request section of the ASGI spec into code.
The heavy lifting is done by the assert_valid_http_request_message function,
the tests mostly serve to wire up hypothesis so that it exercise it's power to find
edge cases.
"""
def test_minimal_request(self):
"""
Smallest viable example. Mostly verifies that our request building works.
"""
request_method, request_path = "GET", "/"
message = message_for_request(request_method, request_path)
self.assert_valid_http_request_message(message, request_method, request_path)
@given(
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params()
)
def test_get_request(self, request_path, request_params):
"""
Tests a typical HTTP GET request, with a path and query parameters
"""
request_method = "GET"
message = message_for_request(request_method, request_path, request_params)
self.assert_valid_http_request_message(
message, request_method, request_path, request_params=request_params)
@given(
request_path=http_strategies.http_path(),
request_body=http_strategies.http_body()
)
def test_post_request(self, request_path, request_body):
"""
Tests a typical POST request, submitting some data in a body.
"""
request_method = "POST"
headers = [content_length_header(request_body)]
message = message_for_request(
request_method, request_path, headers=headers, body=request_body)
self.assert_valid_http_request_message(
message, request_method, request_path,
request_headers=headers, request_body=request_body)
@given(request_headers=http_strategies.headers())
def test_headers(self, request_headers):
"""
Tests that HTTP header fields are handled as specified
"""
request_method, request_path = "OPTIONS", "/te st-à/"
message = message_for_request(request_method, request_path, headers=request_headers)
self.assert_valid_http_request_message(
message, request_method, request_path, request_headers=request_headers)
@given(request_headers=http_strategies.headers())
def test_duplicate_headers(self, request_headers):
"""
Tests that duplicate header values are preserved
"""
assume(len(request_headers) >= 2)
# Set all header field names to the same value
header_name = request_headers[0][0]
duplicated_headers = [(header_name, header[1]) for header in request_headers]
request_method, request_path = "OPTIONS", "/te st-à/"
message = message_for_request(request_method, request_path, headers=duplicated_headers)
self.assert_valid_http_request_message(
message, request_method, request_path, request_headers=duplicated_headers)
@given(
request_method=http_strategies.http_method(),
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params(),
request_headers=http_strategies.headers(),
request_body=http_strategies.http_body(),
)
# This test is slow enough that on Travis, hypothesis sometimes complains.
@settings(suppress_health_check=[HealthCheck.too_slow])
def test_kitchen_sink(
self, request_method, request_path, request_params, request_headers, request_body):
"""
Throw everything at channels that we dare. The idea is that if a combination
of method/path/headers/body would break the spec, hypothesis will eventually find it.
"""
request_headers.append(content_length_header(request_body))
message = message_for_request(
request_method, request_path, request_params, request_headers, request_body)
self.assert_valid_http_request_message(
message, request_method, request_path, request_params, request_headers, request_body)
def test_headers_are_lowercased_and_stripped(self):
request_method, request_path = "GET", "/"
headers = [("MYCUSTOMHEADER", " foobar ")]
message = message_for_request(request_method, request_path, headers=headers)
self.assert_valid_http_request_message(
message, request_method, request_path, request_headers=headers)
# Note that Daphne returns a list of tuples here, which is fine, because the spec
# asks to treat them interchangeably.
assert message["headers"] == [(b"mycustomheader", b"foobar")]
@given(daphne_path=http_strategies.http_path())
def test_root_path_header(self, daphne_path):
"""
Tests root_path handling.
"""
request_method, request_path = "GET", "/"
# Daphne-Root-Path must be URL encoded when submitting as HTTP header field
headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))]
message = message_for_request(request_method, request_path, headers=headers)
# Daphne-Root-Path is not included in the returned 'headers' section. So we expect
# empty headers.
expected_headers = []
self.assert_valid_http_request_message(
message, request_method, request_path, request_headers=expected_headers)
# And what we're looking for, root_path being set.
assert message["root_path"] == daphne_path
class TestProxyHandling(unittest.TestCase):
"""
Tests that concern interaction of Daphne with proxies.
They live in a separate test case, because they're not part of the spec.
"""
def setUp(self):
self.channel_layer = ChannelLayer()
self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)
def test_x_forwarded_for_ignored(self):
self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" +
b"X-Forwarded-For: 10.1.2.3\r\n" +
b"X-Forwarded-Port: 80\r\n" +
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive(["http.request"])
self.assertEqual(message["client"], ["192.168.1.1", 54321])
def test_x_forwarded_for_parsed(self):
self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" +
b"X-Forwarded-For: 10.1.2.3\r\n" +
b"X-Forwarded-Port: 80\r\n" +
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive(["http.request"])
self.assertEqual(message["client"], ["10.1.2.3", 80])
def test_x_forwarded_for_port_missing(self):
self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" +
b"X-Forwarded-For: 10.1.2.3\r\n" +
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive(["http.request"])
self.assertEqual(message["client"], ["10.1.2.3", 0])

View File

@ -1,238 +0,0 @@
"""
Contains a test case class to allow verifying ASGI messages
"""
from __future__ import unicode_literals
from collections import defaultdict
from urllib import parse
import socket
import unittest
from . import factories
class ASGITestCaseBase(unittest.TestCase):
"""
Base class for our test classes which contains shared method.
"""
def assert_is_ip_address(self, address):
"""
Tests whether a given address string is a valid IPv4 or IPv6 address.
"""
try:
socket.inet_aton(address)
except socket.error:
self.fail("'%s' is not a valid IP address." % address)
def assert_presence_of_message_keys(self, keys, required_keys, optional_keys):
present_keys = set(keys)
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
self.assertEqual(set(), present_keys - required_keys - optional_keys)
def assert_valid_reply_channel(self, reply_channel):
self.assertIsInstance(reply_channel, str)
# The reply channel is decided by the server.
self.assertTrue(reply_channel.startswith("test!"))
def assert_valid_path(self, path, request_path):
self.assertIsInstance(path, str)
self.assertEqual(path, request_path)
# Assert that it's already url decoded
self.assertEqual(path, parse.unquote(path))
def assert_valid_address_and_port(self, host):
address, port = host
self.assertIsInstance(address, str)
self.assert_is_ip_address(address)
self.assertIsInstance(port, int)
class ASGIHTTPTestCase(ASGITestCaseBase):
"""
Test case with helpers for verifying HTTP channel messages
"""
def assert_valid_http_request_message(
self, channel_message, request_method, request_path,
request_params=None, request_headers=None, request_body=None):
"""
Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
"""
self.assertTrue(channel_message)
self.assert_presence_of_message_keys(
channel_message.keys(),
{"reply_channel", "http_version", "method", "path", "query_string", "headers"},
{"scheme", "root_path", "body", "body_channel", "client", "server"})
# == Assertions about required channel_message fields ==
self.assert_valid_reply_channel(channel_message["reply_channel"])
self.assert_valid_path(channel_message["path"], request_path)
http_version = channel_message["http_version"]
self.assertIsInstance(http_version, str)
self.assertIn(http_version, ["1.0", "1.1", "1.2"])
method = channel_message["method"]
self.assertIsInstance(method, str)
self.assertTrue(method.isupper())
self.assertEqual(channel_message["method"], request_method)
query_string = channel_message["query_string"]
# Assert that query_string is a byte string and still url encoded
self.assertIsInstance(query_string, bytes)
self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
# {name: [value1, value2, ...]} and check if they're equal.
transformed_message_headers = defaultdict(list)
for name, value in channel_message["headers"]:
transformed_message_headers[name].append(value)
transformed_request_headers = defaultdict(list)
for name, value in (request_headers or []):
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value)
self.assertEqual(transformed_message_headers, transformed_request_headers)
# == Assertions about optional channel_message fields ==
scheme = channel_message.get("scheme")
if scheme is not None:
self.assertIsInstance(scheme, str)
self.assertTrue(scheme) # May not be empty
root_path = channel_message.get("root_path")
if root_path is not None:
self.assertIsInstance(root_path, str)
body = channel_message.get("body")
# Ensure we test for presence of 'body' if a request body was given
if request_body is not None or body is not None:
self.assertIsInstance(body, str)
self.assertEqual(body, (request_body or "").encode("ascii"))
body_channel = channel_message.get("body_channel")
if body_channel is not None:
self.assertIsInstance(body_channel, str)
self.assertIn("?", body_channel)
client = channel_message.get("client")
if client is not None:
self.assert_valid_address_and_port(channel_message["client"])
server = channel_message.get("server")
if server is not None:
self.assert_valid_address_and_port(channel_message["server"])
def assert_valid_http_response_message(self, message, response):
self.assertTrue(message)
self.assertTrue(response.startswith(b"HTTP"))
status_code_bytes = str(message["status"]).encode("ascii")
self.assertIn(status_code_bytes, response)
if "content" in message:
self.assertIn(message["content"], response)
# Check that headers are in the given order.
# N.b. HTTP spec only enforces that the order of header values is kept, but
# the ASGI spec requires that order of all headers is kept. This code
# checks conformance with the stricter ASGI spec.
if "headers" in message:
for name, value in message["headers"]:
expected_header = factories.header_line(name, value)
# Daphne or Twisted turn our lower cased header names ('foo-bar') into title
# case ('Foo-Bar'). So technically we want to to match that the header name is
# present while ignoring casing, and want to ensure the value is present without
# altered casing. The approach below does this well enough.
self.assertIn(expected_header.lower(), response.lower())
self.assertIn(value.encode("ascii"), response)
class ASGIWebSocketTestCase(ASGITestCaseBase):
"""
Test case with helpers for verifying WebSocket channel messages
"""
def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
self.assertIn(body, response)
self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
def assert_websocket_denied(self, response):
self.assertIn(b"HTTP/1.1 403", response)
def assert_valid_websocket_connect_message(
self, channel_message, request_path="/", request_params=None, request_headers=None):
"""
Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
"""
self.assertTrue(channel_message)
self.assert_presence_of_message_keys(
channel_message.keys(),
{"reply_channel", "path", "headers", "order"},
{"scheme", "query_string", "root_path", "client", "server"})
# == Assertions about required channel_message fields ==
self.assert_valid_reply_channel(channel_message["reply_channel"])
self.assert_valid_path(channel_message["path"], request_path)
order = channel_message["order"]
self.assertIsInstance(order, int)
self.assertEqual(order, 0)
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform the request
# headers and the channel message headers into a set
# {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
# Note that unlike for HTTP, Daphne never gives out individual header values; instead we
# get one string per header field with values separated by comma.
transformed_request_headers = defaultdict(list)
for name, value in (request_headers or []):
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value)
final_request_headers = {
(name, b",".join(value)) for name, value in transformed_request_headers.items()
}
# Websockets carry a lot of additional header fields, so instead of verifying that
# headers look exactly like expected, we just check that the expected header fields
# and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
# and not tested for.
assert final_request_headers.issubset(set(channel_message["headers"]))
# == Assertions about optional channel_message fields ==
scheme = channel_message.get("scheme")
if scheme:
self.assertIsInstance(scheme, six.text_type)
self.assertIn(scheme, ["ws", "wss"])
query_string = channel_message.get("query_string")
if query_string:
# Assert that query_string is a byte string and still url encoded
self.assertIsInstance(query_string, six.binary_type)
self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
root_path = channel_message.get("root_path")
if root_path is not None:
self.assertIsInstance(root_path, six.text_type)
client = channel_message.get("client")
if client is not None:
self.assert_valid_address_and_port(channel_message["client"])
server = channel_message.get("server")
if server is not None:
self.assert_valid_address_and_port(channel_message["server"])

View File

@ -3,12 +3,3 @@ universal=1
[tool:pytest]
addopts = tests/
[yapf]
based_on_style = pep8
column_limit = 120
join_multiple_lines = false
split_arguments_when_comma_terminated = true
split_before_expression_after_opening_paren = true
split_before_first_argument = true
split_penalty_after_opening_bracket = -10

248
tests/http_base.py Normal file
View File

@ -0,0 +1,248 @@
from urllib import parse
import requests
import socket
import subprocess
import time
import unittest
from daphne.test_utils import TestApplication
class DaphneTestCase(unittest.TestCase):
"""
Base class for Daphne integration test cases.
Boots up a copy of Daphne on a test port and sends it a request, and
retrieves the response. Uses a custom ASGI application and temporary files
to store/retrieve the request/response messages.
"""
def port_in_use(self, port):
"""
Tests if a port is in use on the local machine.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
s.bind(("127.0.0.1", port))
except socket.error as e:
if e.errno in [13, 98]:
return True
else:
raise
else:
return False
finally:
s.close()
def run_daphne(self, method, path, params, data, responses, headers=None, timeout=1):
"""
Runs Daphne with the given request callback (given the base URL)
and response messages.
"""
# Store setup info
TestApplication.clear_storage()
TestApplication.save_setup(
response_messages=responses,
)
# Find a free port
for i in range(11200, 11300):
if not self.port_in_use(i):
port = i
break
else:
raise RuntimeError("Cannot find a free port to test on")
# Launch daphne on that port
process = subprocess.Popen(["daphne", "-p", str(port), "daphne.test_utils:TestApplication"])
try:
for _ in range(100):
time.sleep(0.1)
if self.port_in_use(port):
break
else:
raise RuntimeError("Daphne never came up.")
# Send it the request
url = "http://127.0.0.1:%i%s" % (port, path)
response = requests.request(method, url, params=params, data=data, headers=headers, timeout=timeout)
finally:
# Shut down daphne
process.terminate()
# Load the information
inner_result = TestApplication.load_result()
# Return the inner result and the response
return inner_result, response
def run_daphne_request(self, method, path, params=None, data=None, headers=None):
"""
Convenience method for just testing request handling.
Returns (scope, messages)
"""
if headers is not None:
headers = dict(headers)
inner_result, _ = self.run_daphne(
method=method,
path=path,
params=params,
data=data,
headers=headers,
responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
)
return inner_result["scope"], inner_result["messages"]
def tearDown(self):
"""
Ensures any storage files are cleared.
"""
TestApplication.clear_storage()
def assert_is_ip_address(self, address):
"""
Tests whether a given address string is a valid IPv4 or IPv6 address.
"""
try:
socket.inet_aton(address)
except socket.error:
self.fail("'%s' is not a valid IP address." % address)
def assert_key_sets(self, required_keys, optional_keys, actual_keys):
"""
Asserts that all required_keys are in actual_keys, and that there
are no keys in actual_keys that aren't required or optional.
"""
present_keys = set(actual_keys)
# Make sure all required keys are present
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
self.assertEqual(
set(),
present_keys - required_keys - optional_keys,
)
def assert_valid_path(self, path, request_path):
"""
Checks the path is valid and already url-decoded.
"""
self.assertIsInstance(path, str)
self.assertEqual(path, request_path)
# Assert that it's already url decoded
self.assertEqual(path, parse.unquote(path))
def assert_valid_address_and_port(self, host):
"""
Asserts the value is a valid (host, port) tuple.
"""
address, port = host
self.assertIsInstance(address, str)
self.assert_is_ip_address(address)
self.assertIsInstance(port, int)
# class ASGIHTTPTestCase(ASGITestCaseBase):
# """
# Test case with helpers for verifying HTTP channel messages
# """
# def assert_valid_http_response_message(self, message, response):
# self.assertTrue(message)
# self.assertTrue(response.startswith(b"HTTP"))
# status_code_bytes = str(message["status"]).encode("ascii")
# self.assertIn(status_code_bytes, response)
# if "content" in message:
# self.assertIn(message["content"], response)
# # Check that headers are in the given order.
# # N.b. HTTP spec only enforces that the order of header values is kept, but
# # the ASGI spec requires that order of all headers is kept. This code
# # checks conformance with the stricter ASGI spec.
# if "headers" in message:
# for name, value in message["headers"]:
# expected_header = factories.header_line(name, value)
# # Daphne or Twisted turn our lower cased header names ('foo-bar') into title
# # case ('Foo-Bar'). So technically we want to to match that the header name is
# # present while ignoring casing, and want to ensure the value is present without
# # altered casing. The approach below does this well enough.
# self.assertIn(expected_header.lower(), response.lower())
# self.assertIn(value.encode("ascii"), response)
# class ASGIWebSocketTestCase(ASGITestCaseBase):
# """
# Test case with helpers for verifying WebSocket channel messages
# """
# def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
# self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
# self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
# self.assertIn(body, response)
# self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
# def assert_websocket_denied(self, response):
# self.assertIn(b"HTTP/1.1 403", response)
# def assert_valid_websocket_connect_message(
# self, channel_message, request_path="/", request_params=None, request_headers=None):
# """
# Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
# """
# self.assertTrue(channel_message)
# self.assert_presence_of_message_keys(
# channel_message.keys(),
# {"reply_channel", "path", "headers", "order"},
# {"scheme", "query_string", "root_path", "client", "server"})
# # == Assertions about required channel_message fields ==
# self.assert_valid_reply_channel(channel_message["reply_channel"])
# self.assert_valid_path(channel_message["path"], request_path)
# order = channel_message["order"]
# self.assertIsInstance(order, int)
# self.assertEqual(order, 0)
# # Ordering of header names is not important, but the order of values for a header
# # name is. To assert whether that order is kept, we transform the request
# # headers and the channel message headers into a set
# # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
# # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
# # get one string per header field with values separated by comma.
# transformed_request_headers = defaultdict(list)
# for name, value in (request_headers or []):
# expected_name = name.lower().strip().encode("ascii")
# expected_value = value.strip().encode("ascii")
# transformed_request_headers[expected_name].append(expected_value)
# final_request_headers = {
# (name, b",".join(value)) for name, value in transformed_request_headers.items()
# }
# # Websockets carry a lot of additional header fields, so instead of verifying that
# # headers look exactly like expected, we just check that the expected header fields
# # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
# # and not tested for.
# assert final_request_headers.issubset(set(channel_message["headers"]))
# # == Assertions about optional channel_message fields ==
# scheme = channel_message.get("scheme")
# if scheme:
# self.assertIsInstance(scheme, six.text_type)
# self.assertIn(scheme, ["ws", "wss"])
# query_string = channel_message.get("query_string")
# if query_string:
# # Assert that query_string is a byte string and still url encoded
# self.assertIsInstance(query_string, six.binary_type)
# self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
# root_path = channel_message.get("root_path")
# if root_path is not None:
# self.assertIsInstance(root_path, six.text_type)
# client = channel_message.get("client")
# if client is not None:
# self.assert_valid_address_and_port(channel_message["client"])
# server = channel_message.get("server")
# if server is not None:
# self.assert_valid_address_and_port(channel_message["server"])

View File

@ -27,18 +27,16 @@ def http_path():
Returns a URL path (not encoded).
"""
return strategies.lists(
_http_path_portion(), min_size=0, max_size=10).map(lambda s: "/" + "/".join(s))
_http_path_portion(),
min_size=0,
max_size=10,
).map(lambda s: "/" + "/".join(s))
def http_body():
"""
Returns random printable ASCII characters. This may be exceeding what HTTP allows,
but seems to not cause an issue so far.
Returns random binary body data.
"""
return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500)
def binary_payload():
return strategies.binary(min_size=0, average_size=600, max_size=1500)
@ -59,7 +57,11 @@ def valid_bidi(value):
def _domain_label():
return strategies.text(
alphabet=letters, min_size=1, average_size=6, max_size=63).filter(valid_bidi)
alphabet=letters,
min_size=1,
average_size=6,
max_size=63,
).filter(valid_bidi)
def international_domain_name():
@ -67,12 +69,19 @@ def international_domain_name():
Returns a byte string of a domain name, IDNA-encoded.
"""
return strategies.lists(
_domain_label(), min_size=2, average_size=2).map(lambda s: (".".join(s)).encode("idna"))
_domain_label(),
min_size=2,
average_size=2,
).map(lambda s: (".".join(s)).encode("idna"))
def _query_param():
return strategies.text(alphabet=letters, min_size=1, average_size=10, max_size=255).\
map(lambda s: s.encode("utf8"))
return strategies.text(
alphabet=letters,
min_size=1,
average_size=10,
max_size=255,
).map(lambda s: s.encode("utf8"))
def query_params():
@ -82,8 +91,10 @@ def query_params():
ensures that the total urlencoded query string is not longer than 1500 characters.
"""
return strategies.lists(
strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5).\
filter(lambda x: len(parse.urlencode(x)) < 1500)
strategies.tuples(_query_param(), _query_param()),
min_size=0,
average_size=5,
).filter(lambda x: len(parse.urlencode(x)) < 1500)
def header_name():
@ -94,7 +105,10 @@ def header_name():
and 20 characters long
"""
return strategies.text(
alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30)
alphabet=string.ascii_letters + string.digits + "-",
min_size=1,
max_size=30,
)
def header_value():
@ -106,7 +120,10 @@ def header_value():
"""
return strategies.text(
alphabet=string.ascii_letters + string.digits + string.punctuation + " /t",
min_size=1, average_size=40, max_size=8190).filter(lambda s: len(s.encode("utf8")) < 8190)
min_size=1,
average_size=40,
max_size=8190,
).filter(lambda s: len(s.encode("utf8")) < 8190)
def headers():
@ -118,4 +135,7 @@ def headers():
"""
return strategies.lists(
strategies.tuples(header_name(), header_value()),
min_size=0, average_size=10, max_size=100)
min_size=0,
average_size=10,
max_size=100,
)

267
tests/test_http_request.py Normal file
View File

@ -0,0 +1,267 @@
# coding: utf8
import collections
from urllib import parse
from hypothesis import given, assume, settings, HealthCheck
import http_strategies
from http_base import DaphneTestCase
class TestHTTPRequestSpec(DaphneTestCase):
"""
Tests which try to pour the HTTP request section of the ASGI spec into code.
The heavy lifting is done by the assert_valid_http_request_message function,
the tests mostly serve to wire up hypothesis so that it exercise it's power to find
edge cases.
"""
def assert_valid_http_scope(
self,
scope,
method,
path,
params=None,
headers=None,
scheme=None,
):
"""
Checks that the passed scope is a valid ASGI HTTP scope regarding types
and some urlencoding things.
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type", "http_version", "method", "path", "query_string", "headers"},
optional_keys={"scheme", "root_path", "client", "server"},
actual_keys=scope.keys(),
)
# Check that it is the right type
self.assertEqual(scope["type"], "http")
# Method (uppercased unicode string)
self.assertIsInstance(scope["method"], str)
self.assertEqual(scope["method"], method.upper())
# Path
self.assert_valid_path(scope["path"], path)
# HTTP version
self.assertIn(scope["http_version"], ["1.0", "1.1", "1.2"])
# Scheme
self.assertIn(scope["scheme"], ["http", "https"])
if scheme:
self.assertEqual(scheme, scope["scheme"])
# Query string (byte string and still url encoded)
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
# {name: [value1, value2, ...]} and check if they're equal.
transformed_scope_headers = collections.defaultdict(list)
for name, value in scope["headers"]:
transformed_scope_headers[name].append(value)
transformed_request_headers = collections.defaultdict(list)
for name, value in (headers or []):
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value)
for name, value in transformed_request_headers.items():
self.assertIn(name, transformed_scope_headers)
self.assertEqual(value, transformed_scope_headers[name])
# Root path
self.assertIsInstance(scope.get("root_path", ""), str)
# Client and server addresses
client = scope.get("client")
if client is not None:
self.assert_valid_address_and_port(client)
server = scope.get("server")
if server is not None:
self.assert_valid_address_and_port(server)
def assert_valid_http_request_message(
self,
message,
body=None,
):
"""
Asserts that a message is a valid http.request message
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type"},
optional_keys={"body", "more_content"},
actual_keys=message.keys(),
)
# Check that it is the right type
self.assertEqual(message["type"], "http.request")
# If there's a body present, check its type
self.assertIsInstance(message.get("body", b""), bytes)
if body is not None:
self.assertEqual(body, message.get("body", b""))
def test_minimal_request(self):
"""
Smallest viable example. Mostly verifies that our request building works.
"""
scope, messages = self.run_daphne_request("GET", "/")
self.assert_valid_http_scope(scope, "GET", "/")
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params()
)
@settings(max_examples=5, deadline=2000)
def test_get_request(self, request_path, request_params):
"""
Tests a typical HTTP GET request, with a path and query parameters
"""
scope, messages = self.run_daphne_request("GET", request_path, params=request_params)
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
request_path=http_strategies.http_path(),
request_body=http_strategies.http_body()
)
@settings(max_examples=5, deadline=2000)
def test_post_request(self, request_path, request_body):
"""
Tests a typical HTTP POST request, with a path and body.
"""
scope, messages = self.run_daphne_request("POST", request_path, data=request_body)
self.assert_valid_http_scope(scope, "POST", request_path)
self.assert_valid_http_request_message(messages[0], body=request_body)
@given(request_headers=http_strategies.headers())
@settings(max_examples=5, deadline=2000)
def test_headers(self, request_headers):
"""
Tests that HTTP header fields are handled as specified
"""
request_path = "/te st-à/"
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers)
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers)
self.assert_valid_http_request_message(messages[0], body=b"")
# @given(request_headers=http_strategies.headers())
# def test_duplicate_headers(self, request_headers):
# """
# Tests that duplicate header values are preserved
# """
# assume(len(request_headers) >= 2)
# # Set all header field names to the same value
# header_name = request_headers[0][0]
# duplicated_headers = [(header_name, header[1]) for header in request_headers]
# request_method, request_path = "OPTIONS", "/te st-à/"
# message = message_for_request(request_method, request_path, headers=duplicated_headers)
# self.assert_valid_http_request_message(
# message, request_method, request_path, request_headers=duplicated_headers)
# @given(
# request_method=http_strategies.http_method(),
# request_path=http_strategies.http_path(),
# request_params=http_strategies.query_params(),
# request_headers=http_strategies.headers(),
# request_body=http_strategies.http_body(),
# )
# # This test is slow enough that on Travis, hypothesis sometimes complains.
# @settings(suppress_health_check=[HealthCheck.too_slow])
# def test_kitchen_sink(
# self, request_method, request_path, request_params, request_headers, request_body):
# """
# Throw everything at channels that we dare. The idea is that if a combination
# of method/path/headers/body would break the spec, hypothesis will eventually find it.
# """
# request_headers.append(content_length_header(request_body))
# message = message_for_request(
# request_method, request_path, request_params, request_headers, request_body)
# self.assert_valid_http_request_message(
# message, request_method, request_path, request_params, request_headers, request_body)
# def test_headers_are_lowercased_and_stripped(self):
# request_method, request_path = "GET", "/"
# headers = [("MYCUSTOMHEADER", " foobar ")]
# message = message_for_request(request_method, request_path, headers=headers)
# self.assert_valid_http_request_message(
# message, request_method, request_path, request_headers=headers)
# # Note that Daphne returns a list of tuples here, which is fine, because the spec
# # asks to treat them interchangeably.
# assert message["headers"] == [(b"mycustomheader", b"foobar")]
# @given(daphne_path=http_strategies.http_path())
# def test_root_path_header(self, daphne_path):
# """
# Tests root_path handling.
# """
# request_method, request_path = "GET", "/"
# # Daphne-Root-Path must be URL encoded when submitting as HTTP header field
# headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))]
# message = message_for_request(request_method, request_path, headers=headers)
# # Daphne-Root-Path is not included in the returned 'headers' section. So we expect
# # empty headers.
# expected_headers = []
# self.assert_valid_http_request_message(
# message, request_method, request_path, request_headers=expected_headers)
# # And what we're looking for, root_path being set.
# assert message["root_path"] == daphne_path
# class TestProxyHandling(unittest.TestCase):
# """
# Tests that concern interaction of Daphne with proxies.
# They live in a separate test case, because they're not part of the spec.
# """
# def setUp(self):
# self.channel_layer = ChannelLayer()
# self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
# self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
# self.tr = proto_helpers.StringTransport()
# self.proto.makeConnection(self.tr)
# def test_x_forwarded_for_ignored(self):
# self.proto.dataReceived(
# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
# b"Host: somewhere.com\r\n" +
# b"X-Forwarded-For: 10.1.2.3\r\n" +
# b"X-Forwarded-Port: 80\r\n" +
# b"\r\n"
# )
# # Get the resulting message off of the channel layer
# _, message = self.channel_layer.receive(["http.request"])
# self.assertEqual(message["client"], ["192.168.1.1", 54321])
# def test_x_forwarded_for_parsed(self):
# self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
# self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
# self.proto.dataReceived(
# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
# b"Host: somewhere.com\r\n" +
# b"X-Forwarded-For: 10.1.2.3\r\n" +
# b"X-Forwarded-Port: 80\r\n" +
# b"\r\n"
# )
# # Get the resulting message off of the channel layer
# _, message = self.channel_layer.receive(["http.request"])
# self.assertEqual(message["client"], ["10.1.2.3", 80])
# def test_x_forwarded_for_port_missing(self):
# self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
# self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
# self.proto.dataReceived(
# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
# b"Host: somewhere.com\r\n" +
# b"X-Forwarded-For: 10.1.2.3\r\n" +
# b"\r\n"
# )
# # Get the resulting message off of the channel layer
# _, message = self.channel_layer.receive(["http.request"])
# self.assertEqual(message["client"], ["10.1.2.3", 0])

View File

@ -1,11 +1,9 @@
# coding: utf8
from __future__ import unicode_literals
from unittest import TestCase
import six
from twisted.web.http_headers import Headers
from unittest import TestCase
from ..utils import parse_x_forwarded_for
from daphne.utils import parse_x_forwarded_for
class TestXForwardedForHttpParsing(TestCase):
@ -20,7 +18,7 @@ class TestXForwardedForHttpParsing(TestCase):
})
result = parse_x_forwarded_for(headers)
self.assertEqual(result, ["10.1.2.3", 1234])
self.assertIsInstance(result[0], six.text_type)
self.assertIsInstance(result[0], str)
def test_address_only(self):
headers = Headers({
@ -94,7 +92,7 @@ class TestXForwardedForWsParsing(TestCase):
["1043::a321:0001", 0]
)
def test_multiple_proxys(self):
def test_multiple_proxies(self):
headers = {
b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4",
}