mirror of
https://github.com/django/daphne.git
synced 2024-11-11 02:26:35 +03:00
HTTP protocol tests
This commit is contained in:
parent
0626f39214
commit
b72349d2c1
|
@ -127,12 +127,7 @@ class WebRequest(http.Request):
|
|||
# Remove our HTTP reply channel association
|
||||
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
|
||||
# Resume the producer so we keep getting data, if it's available as a method
|
||||
# 17.1 version
|
||||
if hasattr(self.channel, "_networkProducer"):
|
||||
self.channel._networkProducer.resumeProducing()
|
||||
# 16.x version
|
||||
elif hasattr(self.channel, "resumeProducing"):
|
||||
self.channel.resumeProducing()
|
||||
self.channel._networkProducer.resumeProducing()
|
||||
|
||||
# Boring old HTTP.
|
||||
else:
|
||||
|
|
|
@ -106,6 +106,12 @@ class Server(object):
|
|||
reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
|
||||
reactor.run(installSignalHandlers=self.signal_handlers)
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Force-stops the server.
|
||||
"""
|
||||
reactor.stop()
|
||||
|
||||
### Protocol handling
|
||||
|
||||
def add_protocol(self, protocol):
|
||||
|
@ -159,16 +165,20 @@ class Server(object):
|
|||
if application_instance.done():
|
||||
exception = application_instance.exception()
|
||||
if exception:
|
||||
logging.error(
|
||||
"Exception inside application: {}\n{}{}".format(
|
||||
exception,
|
||||
"".join(traceback.format_tb(
|
||||
exception.__traceback__,
|
||||
)),
|
||||
" {}".format(exception),
|
||||
if isinstance(exception, KeyboardInterrupt):
|
||||
# Protocol is asking the server to exit (likely during test)
|
||||
self.stop()
|
||||
else:
|
||||
logging.error(
|
||||
"Exception inside application: {}\n{}{}".format(
|
||||
exception,
|
||||
"".join(traceback.format_tb(
|
||||
exception.__traceback__,
|
||||
)),
|
||||
" {}".format(exception),
|
||||
)
|
||||
)
|
||||
)
|
||||
protocol.handle_exception(exception)
|
||||
protocol.handle_exception(exception)
|
||||
try:
|
||||
del self.application_instances[protocol]
|
||||
except KeyError:
|
||||
|
|
86
daphne/test_utils.py
Normal file
86
daphne/test_utils.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
import msgpack
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
||||
class TestApplication:
|
||||
"""
|
||||
An application that receives one or more messages, sends a response,
|
||||
and then quits the server. For testing.
|
||||
"""
|
||||
|
||||
setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio")
|
||||
result_storage = os.path.join(tempfile.gettempdir(), "result.testio")
|
||||
|
||||
def __init__(self, scope):
|
||||
self.scope = scope
|
||||
self.messages = []
|
||||
|
||||
async def __call__(self, send, receive):
|
||||
# Load setup info
|
||||
setup = self.load_setup()
|
||||
try:
|
||||
for _ in range(setup["receive_messages"]):
|
||||
self.messages.append(await receive())
|
||||
for message in setup["response_messages"]:
|
||||
await send(message)
|
||||
finally:
|
||||
self.save_result()
|
||||
|
||||
@classmethod
|
||||
def save_setup(cls, response_messages, receive_messages=1):
|
||||
"""
|
||||
Stores setup information.
|
||||
"""
|
||||
with open(cls.setup_storage, "wb") as fh:
|
||||
fh.write(msgpack.packb(
|
||||
{
|
||||
"response_messages": response_messages,
|
||||
"receive_messages": receive_messages,
|
||||
},
|
||||
use_bin_type=True,
|
||||
))
|
||||
|
||||
@classmethod
|
||||
def load_setup(cls):
|
||||
"""
|
||||
Returns setup details.
|
||||
"""
|
||||
with open(cls.setup_storage, "rb") as fh:
|
||||
return msgpack.unpackb(fh.read(), encoding="utf-8")
|
||||
|
||||
def save_result(self):
|
||||
"""
|
||||
Saves details of what happened to the result storage.
|
||||
We could use pickle here, but that seems wrong, still, somehow.
|
||||
"""
|
||||
with open(self.result_storage, "wb") as fh:
|
||||
fh.write(msgpack.packb(
|
||||
{
|
||||
"scope": self.scope,
|
||||
"messages": self.messages,
|
||||
},
|
||||
use_bin_type=True,
|
||||
))
|
||||
|
||||
@classmethod
|
||||
def load_result(cls):
|
||||
"""
|
||||
Returns result details.
|
||||
"""
|
||||
with open(cls.result_storage, "rb") as fh:
|
||||
return msgpack.unpackb(fh.read(), encoding="utf-8")
|
||||
|
||||
@classmethod
|
||||
def clear_storage(cls):
|
||||
"""
|
||||
Clears storage files.
|
||||
"""
|
||||
try:
|
||||
os.unlink(cls.setup_storage)
|
||||
except OSError:
|
||||
pass
|
||||
try:
|
||||
os.unlink(cls.result_storage)
|
||||
except OSError:
|
||||
pass
|
|
@ -1,197 +0,0 @@
|
|||
# coding: utf8
|
||||
"""
|
||||
Tests for the HTTP request section of the ASGI spec
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
from six.moves.urllib import parse
|
||||
|
||||
from asgiref.inmemory import ChannelLayer
|
||||
from hypothesis import given, assume, settings, HealthCheck
|
||||
from twisted.test import proto_helpers
|
||||
|
||||
from daphne.http_protocol import HTTPFactory
|
||||
from daphne.tests import testcases, http_strategies
|
||||
from daphne.tests.factories import message_for_request, content_length_header
|
||||
|
||||
|
||||
class TestHTTPRequestSpec(testcases.ASGIHTTPTestCase):
|
||||
"""
|
||||
Tests which try to pour the HTTP request section of the ASGI spec into code.
|
||||
The heavy lifting is done by the assert_valid_http_request_message function,
|
||||
the tests mostly serve to wire up hypothesis so that it exercise it's power to find
|
||||
edge cases.
|
||||
"""
|
||||
|
||||
def test_minimal_request(self):
|
||||
"""
|
||||
Smallest viable example. Mostly verifies that our request building works.
|
||||
"""
|
||||
request_method, request_path = "GET", "/"
|
||||
message = message_for_request(request_method, request_path)
|
||||
|
||||
self.assert_valid_http_request_message(message, request_method, request_path)
|
||||
|
||||
@given(
|
||||
request_path=http_strategies.http_path(),
|
||||
request_params=http_strategies.query_params()
|
||||
)
|
||||
def test_get_request(self, request_path, request_params):
|
||||
"""
|
||||
Tests a typical HTTP GET request, with a path and query parameters
|
||||
"""
|
||||
request_method = "GET"
|
||||
message = message_for_request(request_method, request_path, request_params)
|
||||
|
||||
self.assert_valid_http_request_message(
|
||||
message, request_method, request_path, request_params=request_params)
|
||||
|
||||
@given(
|
||||
request_path=http_strategies.http_path(),
|
||||
request_body=http_strategies.http_body()
|
||||
)
|
||||
def test_post_request(self, request_path, request_body):
|
||||
"""
|
||||
Tests a typical POST request, submitting some data in a body.
|
||||
"""
|
||||
request_method = "POST"
|
||||
headers = [content_length_header(request_body)]
|
||||
message = message_for_request(
|
||||
request_method, request_path, headers=headers, body=request_body)
|
||||
|
||||
self.assert_valid_http_request_message(
|
||||
message, request_method, request_path,
|
||||
request_headers=headers, request_body=request_body)
|
||||
|
||||
@given(request_headers=http_strategies.headers())
|
||||
def test_headers(self, request_headers):
|
||||
"""
|
||||
Tests that HTTP header fields are handled as specified
|
||||
"""
|
||||
request_method, request_path = "OPTIONS", "/te st-à/"
|
||||
message = message_for_request(request_method, request_path, headers=request_headers)
|
||||
|
||||
self.assert_valid_http_request_message(
|
||||
message, request_method, request_path, request_headers=request_headers)
|
||||
|
||||
@given(request_headers=http_strategies.headers())
|
||||
def test_duplicate_headers(self, request_headers):
|
||||
"""
|
||||
Tests that duplicate header values are preserved
|
||||
"""
|
||||
assume(len(request_headers) >= 2)
|
||||
# Set all header field names to the same value
|
||||
header_name = request_headers[0][0]
|
||||
duplicated_headers = [(header_name, header[1]) for header in request_headers]
|
||||
|
||||
request_method, request_path = "OPTIONS", "/te st-à/"
|
||||
message = message_for_request(request_method, request_path, headers=duplicated_headers)
|
||||
|
||||
self.assert_valid_http_request_message(
|
||||
message, request_method, request_path, request_headers=duplicated_headers)
|
||||
|
||||
@given(
|
||||
request_method=http_strategies.http_method(),
|
||||
request_path=http_strategies.http_path(),
|
||||
request_params=http_strategies.query_params(),
|
||||
request_headers=http_strategies.headers(),
|
||||
request_body=http_strategies.http_body(),
|
||||
)
|
||||
# This test is slow enough that on Travis, hypothesis sometimes complains.
|
||||
@settings(suppress_health_check=[HealthCheck.too_slow])
|
||||
def test_kitchen_sink(
|
||||
self, request_method, request_path, request_params, request_headers, request_body):
|
||||
"""
|
||||
Throw everything at channels that we dare. The idea is that if a combination
|
||||
of method/path/headers/body would break the spec, hypothesis will eventually find it.
|
||||
"""
|
||||
request_headers.append(content_length_header(request_body))
|
||||
message = message_for_request(
|
||||
request_method, request_path, request_params, request_headers, request_body)
|
||||
|
||||
self.assert_valid_http_request_message(
|
||||
message, request_method, request_path, request_params, request_headers, request_body)
|
||||
|
||||
def test_headers_are_lowercased_and_stripped(self):
|
||||
request_method, request_path = "GET", "/"
|
||||
headers = [("MYCUSTOMHEADER", " foobar ")]
|
||||
message = message_for_request(request_method, request_path, headers=headers)
|
||||
|
||||
self.assert_valid_http_request_message(
|
||||
message, request_method, request_path, request_headers=headers)
|
||||
# Note that Daphne returns a list of tuples here, which is fine, because the spec
|
||||
# asks to treat them interchangeably.
|
||||
assert message["headers"] == [(b"mycustomheader", b"foobar")]
|
||||
|
||||
@given(daphne_path=http_strategies.http_path())
|
||||
def test_root_path_header(self, daphne_path):
|
||||
"""
|
||||
Tests root_path handling.
|
||||
"""
|
||||
request_method, request_path = "GET", "/"
|
||||
# Daphne-Root-Path must be URL encoded when submitting as HTTP header field
|
||||
headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))]
|
||||
message = message_for_request(request_method, request_path, headers=headers)
|
||||
|
||||
# Daphne-Root-Path is not included in the returned 'headers' section. So we expect
|
||||
# empty headers.
|
||||
expected_headers = []
|
||||
self.assert_valid_http_request_message(
|
||||
message, request_method, request_path, request_headers=expected_headers)
|
||||
# And what we're looking for, root_path being set.
|
||||
assert message["root_path"] == daphne_path
|
||||
|
||||
|
||||
class TestProxyHandling(unittest.TestCase):
|
||||
"""
|
||||
Tests that concern interaction of Daphne with proxies.
|
||||
|
||||
They live in a separate test case, because they're not part of the spec.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.channel_layer = ChannelLayer()
|
||||
self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
|
||||
self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
|
||||
self.tr = proto_helpers.StringTransport()
|
||||
self.proto.makeConnection(self.tr)
|
||||
|
||||
def test_x_forwarded_for_ignored(self):
|
||||
self.proto.dataReceived(
|
||||
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
|
||||
b"Host: somewhere.com\r\n" +
|
||||
b"X-Forwarded-For: 10.1.2.3\r\n" +
|
||||
b"X-Forwarded-Port: 80\r\n" +
|
||||
b"\r\n"
|
||||
)
|
||||
# Get the resulting message off of the channel layer
|
||||
_, message = self.channel_layer.receive(["http.request"])
|
||||
self.assertEqual(message["client"], ["192.168.1.1", 54321])
|
||||
|
||||
def test_x_forwarded_for_parsed(self):
|
||||
self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
|
||||
self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
|
||||
self.proto.dataReceived(
|
||||
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
|
||||
b"Host: somewhere.com\r\n" +
|
||||
b"X-Forwarded-For: 10.1.2.3\r\n" +
|
||||
b"X-Forwarded-Port: 80\r\n" +
|
||||
b"\r\n"
|
||||
)
|
||||
# Get the resulting message off of the channel layer
|
||||
_, message = self.channel_layer.receive(["http.request"])
|
||||
self.assertEqual(message["client"], ["10.1.2.3", 80])
|
||||
|
||||
def test_x_forwarded_for_port_missing(self):
|
||||
self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
|
||||
self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
|
||||
self.proto.dataReceived(
|
||||
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
|
||||
b"Host: somewhere.com\r\n" +
|
||||
b"X-Forwarded-For: 10.1.2.3\r\n" +
|
||||
b"\r\n"
|
||||
)
|
||||
# Get the resulting message off of the channel layer
|
||||
_, message = self.channel_layer.receive(["http.request"])
|
||||
self.assertEqual(message["client"], ["10.1.2.3", 0])
|
|
@ -1,238 +0,0 @@
|
|||
"""
|
||||
Contains a test case class to allow verifying ASGI messages
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from collections import defaultdict
|
||||
from urllib import parse
|
||||
import socket
|
||||
import unittest
|
||||
|
||||
from . import factories
|
||||
|
||||
|
||||
class ASGITestCaseBase(unittest.TestCase):
|
||||
"""
|
||||
Base class for our test classes which contains shared method.
|
||||
"""
|
||||
|
||||
def assert_is_ip_address(self, address):
|
||||
"""
|
||||
Tests whether a given address string is a valid IPv4 or IPv6 address.
|
||||
"""
|
||||
try:
|
||||
socket.inet_aton(address)
|
||||
except socket.error:
|
||||
self.fail("'%s' is not a valid IP address." % address)
|
||||
|
||||
def assert_presence_of_message_keys(self, keys, required_keys, optional_keys):
|
||||
present_keys = set(keys)
|
||||
self.assertTrue(required_keys <= present_keys)
|
||||
# Assert that no other keys are present
|
||||
self.assertEqual(set(), present_keys - required_keys - optional_keys)
|
||||
|
||||
def assert_valid_reply_channel(self, reply_channel):
|
||||
self.assertIsInstance(reply_channel, str)
|
||||
# The reply channel is decided by the server.
|
||||
self.assertTrue(reply_channel.startswith("test!"))
|
||||
|
||||
def assert_valid_path(self, path, request_path):
|
||||
self.assertIsInstance(path, str)
|
||||
self.assertEqual(path, request_path)
|
||||
# Assert that it's already url decoded
|
||||
self.assertEqual(path, parse.unquote(path))
|
||||
|
||||
def assert_valid_address_and_port(self, host):
|
||||
address, port = host
|
||||
self.assertIsInstance(address, str)
|
||||
self.assert_is_ip_address(address)
|
||||
self.assertIsInstance(port, int)
|
||||
|
||||
|
||||
class ASGIHTTPTestCase(ASGITestCaseBase):
|
||||
"""
|
||||
Test case with helpers for verifying HTTP channel messages
|
||||
"""
|
||||
|
||||
def assert_valid_http_request_message(
|
||||
self, channel_message, request_method, request_path,
|
||||
request_params=None, request_headers=None, request_body=None):
|
||||
"""
|
||||
Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
|
||||
"""
|
||||
|
||||
self.assertTrue(channel_message)
|
||||
|
||||
self.assert_presence_of_message_keys(
|
||||
channel_message.keys(),
|
||||
{"reply_channel", "http_version", "method", "path", "query_string", "headers"},
|
||||
{"scheme", "root_path", "body", "body_channel", "client", "server"})
|
||||
|
||||
# == Assertions about required channel_message fields ==
|
||||
self.assert_valid_reply_channel(channel_message["reply_channel"])
|
||||
self.assert_valid_path(channel_message["path"], request_path)
|
||||
|
||||
http_version = channel_message["http_version"]
|
||||
self.assertIsInstance(http_version, str)
|
||||
self.assertIn(http_version, ["1.0", "1.1", "1.2"])
|
||||
|
||||
method = channel_message["method"]
|
||||
self.assertIsInstance(method, str)
|
||||
self.assertTrue(method.isupper())
|
||||
self.assertEqual(channel_message["method"], request_method)
|
||||
|
||||
query_string = channel_message["query_string"]
|
||||
# Assert that query_string is a byte string and still url encoded
|
||||
self.assertIsInstance(query_string, bytes)
|
||||
self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
|
||||
|
||||
# Ordering of header names is not important, but the order of values for a header
|
||||
# name is. To assert whether that order is kept, we transform both the request
|
||||
# headers and the channel message headers into a dictionary
|
||||
# {name: [value1, value2, ...]} and check if they're equal.
|
||||
transformed_message_headers = defaultdict(list)
|
||||
for name, value in channel_message["headers"]:
|
||||
transformed_message_headers[name].append(value)
|
||||
|
||||
transformed_request_headers = defaultdict(list)
|
||||
for name, value in (request_headers or []):
|
||||
expected_name = name.lower().strip().encode("ascii")
|
||||
expected_value = value.strip().encode("ascii")
|
||||
transformed_request_headers[expected_name].append(expected_value)
|
||||
|
||||
self.assertEqual(transformed_message_headers, transformed_request_headers)
|
||||
|
||||
# == Assertions about optional channel_message fields ==
|
||||
|
||||
scheme = channel_message.get("scheme")
|
||||
if scheme is not None:
|
||||
self.assertIsInstance(scheme, str)
|
||||
self.assertTrue(scheme) # May not be empty
|
||||
|
||||
root_path = channel_message.get("root_path")
|
||||
if root_path is not None:
|
||||
self.assertIsInstance(root_path, str)
|
||||
|
||||
body = channel_message.get("body")
|
||||
# Ensure we test for presence of 'body' if a request body was given
|
||||
if request_body is not None or body is not None:
|
||||
self.assertIsInstance(body, str)
|
||||
self.assertEqual(body, (request_body or "").encode("ascii"))
|
||||
|
||||
body_channel = channel_message.get("body_channel")
|
||||
if body_channel is not None:
|
||||
self.assertIsInstance(body_channel, str)
|
||||
self.assertIn("?", body_channel)
|
||||
|
||||
client = channel_message.get("client")
|
||||
if client is not None:
|
||||
self.assert_valid_address_and_port(channel_message["client"])
|
||||
|
||||
server = channel_message.get("server")
|
||||
if server is not None:
|
||||
self.assert_valid_address_and_port(channel_message["server"])
|
||||
|
||||
def assert_valid_http_response_message(self, message, response):
|
||||
self.assertTrue(message)
|
||||
self.assertTrue(response.startswith(b"HTTP"))
|
||||
|
||||
status_code_bytes = str(message["status"]).encode("ascii")
|
||||
self.assertIn(status_code_bytes, response)
|
||||
|
||||
if "content" in message:
|
||||
self.assertIn(message["content"], response)
|
||||
|
||||
# Check that headers are in the given order.
|
||||
# N.b. HTTP spec only enforces that the order of header values is kept, but
|
||||
# the ASGI spec requires that order of all headers is kept. This code
|
||||
# checks conformance with the stricter ASGI spec.
|
||||
if "headers" in message:
|
||||
for name, value in message["headers"]:
|
||||
expected_header = factories.header_line(name, value)
|
||||
# Daphne or Twisted turn our lower cased header names ('foo-bar') into title
|
||||
# case ('Foo-Bar'). So technically we want to to match that the header name is
|
||||
# present while ignoring casing, and want to ensure the value is present without
|
||||
# altered casing. The approach below does this well enough.
|
||||
self.assertIn(expected_header.lower(), response.lower())
|
||||
self.assertIn(value.encode("ascii"), response)
|
||||
|
||||
|
||||
class ASGIWebSocketTestCase(ASGITestCaseBase):
|
||||
"""
|
||||
Test case with helpers for verifying WebSocket channel messages
|
||||
"""
|
||||
|
||||
def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
|
||||
self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
|
||||
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
|
||||
self.assertIn(body, response)
|
||||
self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
|
||||
|
||||
def assert_websocket_denied(self, response):
|
||||
self.assertIn(b"HTTP/1.1 403", response)
|
||||
|
||||
def assert_valid_websocket_connect_message(
|
||||
self, channel_message, request_path="/", request_params=None, request_headers=None):
|
||||
"""
|
||||
Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
|
||||
"""
|
||||
|
||||
self.assertTrue(channel_message)
|
||||
|
||||
self.assert_presence_of_message_keys(
|
||||
channel_message.keys(),
|
||||
{"reply_channel", "path", "headers", "order"},
|
||||
{"scheme", "query_string", "root_path", "client", "server"})
|
||||
|
||||
# == Assertions about required channel_message fields ==
|
||||
self.assert_valid_reply_channel(channel_message["reply_channel"])
|
||||
self.assert_valid_path(channel_message["path"], request_path)
|
||||
|
||||
order = channel_message["order"]
|
||||
self.assertIsInstance(order, int)
|
||||
self.assertEqual(order, 0)
|
||||
|
||||
# Ordering of header names is not important, but the order of values for a header
|
||||
# name is. To assert whether that order is kept, we transform the request
|
||||
# headers and the channel message headers into a set
|
||||
# {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
|
||||
# Note that unlike for HTTP, Daphne never gives out individual header values; instead we
|
||||
# get one string per header field with values separated by comma.
|
||||
transformed_request_headers = defaultdict(list)
|
||||
for name, value in (request_headers or []):
|
||||
expected_name = name.lower().strip().encode("ascii")
|
||||
expected_value = value.strip().encode("ascii")
|
||||
transformed_request_headers[expected_name].append(expected_value)
|
||||
final_request_headers = {
|
||||
(name, b",".join(value)) for name, value in transformed_request_headers.items()
|
||||
}
|
||||
|
||||
# Websockets carry a lot of additional header fields, so instead of verifying that
|
||||
# headers look exactly like expected, we just check that the expected header fields
|
||||
# and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
|
||||
# and not tested for.
|
||||
assert final_request_headers.issubset(set(channel_message["headers"]))
|
||||
|
||||
# == Assertions about optional channel_message fields ==
|
||||
scheme = channel_message.get("scheme")
|
||||
if scheme:
|
||||
self.assertIsInstance(scheme, six.text_type)
|
||||
self.assertIn(scheme, ["ws", "wss"])
|
||||
|
||||
query_string = channel_message.get("query_string")
|
||||
if query_string:
|
||||
# Assert that query_string is a byte string and still url encoded
|
||||
self.assertIsInstance(query_string, six.binary_type)
|
||||
self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
|
||||
|
||||
root_path = channel_message.get("root_path")
|
||||
if root_path is not None:
|
||||
self.assertIsInstance(root_path, six.text_type)
|
||||
|
||||
client = channel_message.get("client")
|
||||
if client is not None:
|
||||
self.assert_valid_address_and_port(channel_message["client"])
|
||||
|
||||
server = channel_message.get("server")
|
||||
if server is not None:
|
||||
self.assert_valid_address_and_port(channel_message["server"])
|
|
@ -3,12 +3,3 @@ universal=1
|
|||
|
||||
[tool:pytest]
|
||||
addopts = tests/
|
||||
|
||||
[yapf]
|
||||
based_on_style = pep8
|
||||
column_limit = 120
|
||||
join_multiple_lines = false
|
||||
split_arguments_when_comma_terminated = true
|
||||
split_before_expression_after_opening_paren = true
|
||||
split_before_first_argument = true
|
||||
split_penalty_after_opening_bracket = -10
|
||||
|
|
248
tests/http_base.py
Normal file
248
tests/http_base.py
Normal file
|
@ -0,0 +1,248 @@
|
|||
from urllib import parse
|
||||
import requests
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from daphne.test_utils import TestApplication
|
||||
|
||||
|
||||
class DaphneTestCase(unittest.TestCase):
|
||||
"""
|
||||
Base class for Daphne integration test cases.
|
||||
|
||||
Boots up a copy of Daphne on a test port and sends it a request, and
|
||||
retrieves the response. Uses a custom ASGI application and temporary files
|
||||
to store/retrieve the request/response messages.
|
||||
"""
|
||||
|
||||
def port_in_use(self, port):
|
||||
"""
|
||||
Tests if a port is in use on the local machine.
|
||||
"""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
s.bind(("127.0.0.1", port))
|
||||
except socket.error as e:
|
||||
if e.errno in [13, 98]:
|
||||
return True
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
return False
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
def run_daphne(self, method, path, params, data, responses, headers=None, timeout=1):
|
||||
"""
|
||||
Runs Daphne with the given request callback (given the base URL)
|
||||
and response messages.
|
||||
"""
|
||||
# Store setup info
|
||||
TestApplication.clear_storage()
|
||||
TestApplication.save_setup(
|
||||
response_messages=responses,
|
||||
)
|
||||
# Find a free port
|
||||
for i in range(11200, 11300):
|
||||
if not self.port_in_use(i):
|
||||
port = i
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Cannot find a free port to test on")
|
||||
# Launch daphne on that port
|
||||
process = subprocess.Popen(["daphne", "-p", str(port), "daphne.test_utils:TestApplication"])
|
||||
try:
|
||||
for _ in range(100):
|
||||
time.sleep(0.1)
|
||||
if self.port_in_use(port):
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Daphne never came up.")
|
||||
# Send it the request
|
||||
url = "http://127.0.0.1:%i%s" % (port, path)
|
||||
response = requests.request(method, url, params=params, data=data, headers=headers, timeout=timeout)
|
||||
finally:
|
||||
# Shut down daphne
|
||||
process.terminate()
|
||||
# Load the information
|
||||
inner_result = TestApplication.load_result()
|
||||
# Return the inner result and the response
|
||||
return inner_result, response
|
||||
|
||||
def run_daphne_request(self, method, path, params=None, data=None, headers=None):
|
||||
"""
|
||||
Convenience method for just testing request handling.
|
||||
Returns (scope, messages)
|
||||
"""
|
||||
if headers is not None:
|
||||
headers = dict(headers)
|
||||
inner_result, _ = self.run_daphne(
|
||||
method=method,
|
||||
path=path,
|
||||
params=params,
|
||||
data=data,
|
||||
headers=headers,
|
||||
responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
|
||||
)
|
||||
return inner_result["scope"], inner_result["messages"]
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Ensures any storage files are cleared.
|
||||
"""
|
||||
TestApplication.clear_storage()
|
||||
|
||||
def assert_is_ip_address(self, address):
|
||||
"""
|
||||
Tests whether a given address string is a valid IPv4 or IPv6 address.
|
||||
"""
|
||||
try:
|
||||
socket.inet_aton(address)
|
||||
except socket.error:
|
||||
self.fail("'%s' is not a valid IP address." % address)
|
||||
|
||||
def assert_key_sets(self, required_keys, optional_keys, actual_keys):
|
||||
"""
|
||||
Asserts that all required_keys are in actual_keys, and that there
|
||||
are no keys in actual_keys that aren't required or optional.
|
||||
"""
|
||||
present_keys = set(actual_keys)
|
||||
# Make sure all required keys are present
|
||||
self.assertTrue(required_keys <= present_keys)
|
||||
# Assert that no other keys are present
|
||||
self.assertEqual(
|
||||
set(),
|
||||
present_keys - required_keys - optional_keys,
|
||||
)
|
||||
|
||||
def assert_valid_path(self, path, request_path):
|
||||
"""
|
||||
Checks the path is valid and already url-decoded.
|
||||
"""
|
||||
self.assertIsInstance(path, str)
|
||||
self.assertEqual(path, request_path)
|
||||
# Assert that it's already url decoded
|
||||
self.assertEqual(path, parse.unquote(path))
|
||||
|
||||
def assert_valid_address_and_port(self, host):
|
||||
"""
|
||||
Asserts the value is a valid (host, port) tuple.
|
||||
"""
|
||||
address, port = host
|
||||
self.assertIsInstance(address, str)
|
||||
self.assert_is_ip_address(address)
|
||||
self.assertIsInstance(port, int)
|
||||
|
||||
|
||||
# class ASGIHTTPTestCase(ASGITestCaseBase):
|
||||
# """
|
||||
# Test case with helpers for verifying HTTP channel messages
|
||||
# """
|
||||
|
||||
|
||||
# def assert_valid_http_response_message(self, message, response):
|
||||
# self.assertTrue(message)
|
||||
# self.assertTrue(response.startswith(b"HTTP"))
|
||||
|
||||
# status_code_bytes = str(message["status"]).encode("ascii")
|
||||
# self.assertIn(status_code_bytes, response)
|
||||
|
||||
# if "content" in message:
|
||||
# self.assertIn(message["content"], response)
|
||||
|
||||
# # Check that headers are in the given order.
|
||||
# # N.b. HTTP spec only enforces that the order of header values is kept, but
|
||||
# # the ASGI spec requires that order of all headers is kept. This code
|
||||
# # checks conformance with the stricter ASGI spec.
|
||||
# if "headers" in message:
|
||||
# for name, value in message["headers"]:
|
||||
# expected_header = factories.header_line(name, value)
|
||||
# # Daphne or Twisted turn our lower cased header names ('foo-bar') into title
|
||||
# # case ('Foo-Bar'). So technically we want to to match that the header name is
|
||||
# # present while ignoring casing, and want to ensure the value is present without
|
||||
# # altered casing. The approach below does this well enough.
|
||||
# self.assertIn(expected_header.lower(), response.lower())
|
||||
# self.assertIn(value.encode("ascii"), response)
|
||||
|
||||
|
||||
# class ASGIWebSocketTestCase(ASGITestCaseBase):
|
||||
# """
|
||||
# Test case with helpers for verifying WebSocket channel messages
|
||||
# """
|
||||
|
||||
# def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
|
||||
# self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
|
||||
# self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
|
||||
# self.assertIn(body, response)
|
||||
# self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
|
||||
|
||||
# def assert_websocket_denied(self, response):
|
||||
# self.assertIn(b"HTTP/1.1 403", response)
|
||||
|
||||
# def assert_valid_websocket_connect_message(
|
||||
# self, channel_message, request_path="/", request_params=None, request_headers=None):
|
||||
# """
|
||||
# Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
|
||||
# """
|
||||
|
||||
# self.assertTrue(channel_message)
|
||||
|
||||
# self.assert_presence_of_message_keys(
|
||||
# channel_message.keys(),
|
||||
# {"reply_channel", "path", "headers", "order"},
|
||||
# {"scheme", "query_string", "root_path", "client", "server"})
|
||||
|
||||
# # == Assertions about required channel_message fields ==
|
||||
# self.assert_valid_reply_channel(channel_message["reply_channel"])
|
||||
# self.assert_valid_path(channel_message["path"], request_path)
|
||||
|
||||
# order = channel_message["order"]
|
||||
# self.assertIsInstance(order, int)
|
||||
# self.assertEqual(order, 0)
|
||||
|
||||
# # Ordering of header names is not important, but the order of values for a header
|
||||
# # name is. To assert whether that order is kept, we transform the request
|
||||
# # headers and the channel message headers into a set
|
||||
# # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
|
||||
# # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
|
||||
# # get one string per header field with values separated by comma.
|
||||
# transformed_request_headers = defaultdict(list)
|
||||
# for name, value in (request_headers or []):
|
||||
# expected_name = name.lower().strip().encode("ascii")
|
||||
# expected_value = value.strip().encode("ascii")
|
||||
# transformed_request_headers[expected_name].append(expected_value)
|
||||
# final_request_headers = {
|
||||
# (name, b",".join(value)) for name, value in transformed_request_headers.items()
|
||||
# }
|
||||
|
||||
# # Websockets carry a lot of additional header fields, so instead of verifying that
|
||||
# # headers look exactly like expected, we just check that the expected header fields
|
||||
# # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
|
||||
# # and not tested for.
|
||||
# assert final_request_headers.issubset(set(channel_message["headers"]))
|
||||
|
||||
# # == Assertions about optional channel_message fields ==
|
||||
# scheme = channel_message.get("scheme")
|
||||
# if scheme:
|
||||
# self.assertIsInstance(scheme, six.text_type)
|
||||
# self.assertIn(scheme, ["ws", "wss"])
|
||||
|
||||
# query_string = channel_message.get("query_string")
|
||||
# if query_string:
|
||||
# # Assert that query_string is a byte string and still url encoded
|
||||
# self.assertIsInstance(query_string, six.binary_type)
|
||||
# self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
|
||||
|
||||
# root_path = channel_message.get("root_path")
|
||||
# if root_path is not None:
|
||||
# self.assertIsInstance(root_path, six.text_type)
|
||||
|
||||
# client = channel_message.get("client")
|
||||
# if client is not None:
|
||||
# self.assert_valid_address_and_port(channel_message["client"])
|
||||
|
||||
# server = channel_message.get("server")
|
||||
# if server is not None:
|
||||
# self.assert_valid_address_and_port(channel_message["server"])
|
|
@ -27,18 +27,16 @@ def http_path():
|
|||
Returns a URL path (not encoded).
|
||||
"""
|
||||
return strategies.lists(
|
||||
_http_path_portion(), min_size=0, max_size=10).map(lambda s: "/" + "/".join(s))
|
||||
_http_path_portion(),
|
||||
min_size=0,
|
||||
max_size=10,
|
||||
).map(lambda s: "/" + "/".join(s))
|
||||
|
||||
|
||||
def http_body():
|
||||
"""
|
||||
Returns random printable ASCII characters. This may be exceeding what HTTP allows,
|
||||
but seems to not cause an issue so far.
|
||||
Returns random binary body data.
|
||||
"""
|
||||
return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500)
|
||||
|
||||
|
||||
def binary_payload():
|
||||
return strategies.binary(min_size=0, average_size=600, max_size=1500)
|
||||
|
||||
|
||||
|
@ -59,7 +57,11 @@ def valid_bidi(value):
|
|||
|
||||
def _domain_label():
|
||||
return strategies.text(
|
||||
alphabet=letters, min_size=1, average_size=6, max_size=63).filter(valid_bidi)
|
||||
alphabet=letters,
|
||||
min_size=1,
|
||||
average_size=6,
|
||||
max_size=63,
|
||||
).filter(valid_bidi)
|
||||
|
||||
|
||||
def international_domain_name():
|
||||
|
@ -67,12 +69,19 @@ def international_domain_name():
|
|||
Returns a byte string of a domain name, IDNA-encoded.
|
||||
"""
|
||||
return strategies.lists(
|
||||
_domain_label(), min_size=2, average_size=2).map(lambda s: (".".join(s)).encode("idna"))
|
||||
_domain_label(),
|
||||
min_size=2,
|
||||
average_size=2,
|
||||
).map(lambda s: (".".join(s)).encode("idna"))
|
||||
|
||||
|
||||
def _query_param():
|
||||
return strategies.text(alphabet=letters, min_size=1, average_size=10, max_size=255).\
|
||||
map(lambda s: s.encode("utf8"))
|
||||
return strategies.text(
|
||||
alphabet=letters,
|
||||
min_size=1,
|
||||
average_size=10,
|
||||
max_size=255,
|
||||
).map(lambda s: s.encode("utf8"))
|
||||
|
||||
|
||||
def query_params():
|
||||
|
@ -82,8 +91,10 @@ def query_params():
|
|||
ensures that the total urlencoded query string is not longer than 1500 characters.
|
||||
"""
|
||||
return strategies.lists(
|
||||
strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5).\
|
||||
filter(lambda x: len(parse.urlencode(x)) < 1500)
|
||||
strategies.tuples(_query_param(), _query_param()),
|
||||
min_size=0,
|
||||
average_size=5,
|
||||
).filter(lambda x: len(parse.urlencode(x)) < 1500)
|
||||
|
||||
|
||||
def header_name():
|
||||
|
@ -94,7 +105,10 @@ def header_name():
|
|||
and 20 characters long
|
||||
"""
|
||||
return strategies.text(
|
||||
alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30)
|
||||
alphabet=string.ascii_letters + string.digits + "-",
|
||||
min_size=1,
|
||||
max_size=30,
|
||||
)
|
||||
|
||||
|
||||
def header_value():
|
||||
|
@ -106,7 +120,10 @@ def header_value():
|
|||
"""
|
||||
return strategies.text(
|
||||
alphabet=string.ascii_letters + string.digits + string.punctuation + " /t",
|
||||
min_size=1, average_size=40, max_size=8190).filter(lambda s: len(s.encode("utf8")) < 8190)
|
||||
min_size=1,
|
||||
average_size=40,
|
||||
max_size=8190,
|
||||
).filter(lambda s: len(s.encode("utf8")) < 8190)
|
||||
|
||||
|
||||
def headers():
|
||||
|
@ -118,4 +135,7 @@ def headers():
|
|||
"""
|
||||
return strategies.lists(
|
||||
strategies.tuples(header_name(), header_value()),
|
||||
min_size=0, average_size=10, max_size=100)
|
||||
min_size=0,
|
||||
average_size=10,
|
||||
max_size=100,
|
||||
)
|
267
tests/test_http_request.py
Normal file
267
tests/test_http_request.py
Normal file
|
@ -0,0 +1,267 @@
|
|||
# coding: utf8
|
||||
|
||||
import collections
|
||||
from urllib import parse
|
||||
|
||||
from hypothesis import given, assume, settings, HealthCheck
|
||||
|
||||
import http_strategies
|
||||
from http_base import DaphneTestCase
|
||||
|
||||
|
||||
class TestHTTPRequestSpec(DaphneTestCase):
|
||||
"""
|
||||
Tests which try to pour the HTTP request section of the ASGI spec into code.
|
||||
The heavy lifting is done by the assert_valid_http_request_message function,
|
||||
the tests mostly serve to wire up hypothesis so that it exercise it's power to find
|
||||
edge cases.
|
||||
"""
|
||||
|
||||
def assert_valid_http_scope(
|
||||
self,
|
||||
scope,
|
||||
method,
|
||||
path,
|
||||
params=None,
|
||||
headers=None,
|
||||
scheme=None,
|
||||
):
|
||||
"""
|
||||
Checks that the passed scope is a valid ASGI HTTP scope regarding types
|
||||
and some urlencoding things.
|
||||
"""
|
||||
# Check overall keys
|
||||
self.assert_key_sets(
|
||||
required_keys={"type", "http_version", "method", "path", "query_string", "headers"},
|
||||
optional_keys={"scheme", "root_path", "client", "server"},
|
||||
actual_keys=scope.keys(),
|
||||
)
|
||||
# Check that it is the right type
|
||||
self.assertEqual(scope["type"], "http")
|
||||
# Method (uppercased unicode string)
|
||||
self.assertIsInstance(scope["method"], str)
|
||||
self.assertEqual(scope["method"], method.upper())
|
||||
# Path
|
||||
self.assert_valid_path(scope["path"], path)
|
||||
# HTTP version
|
||||
self.assertIn(scope["http_version"], ["1.0", "1.1", "1.2"])
|
||||
# Scheme
|
||||
self.assertIn(scope["scheme"], ["http", "https"])
|
||||
if scheme:
|
||||
self.assertEqual(scheme, scope["scheme"])
|
||||
# Query string (byte string and still url encoded)
|
||||
query_string = scope["query_string"]
|
||||
self.assertIsInstance(query_string, bytes)
|
||||
if params:
|
||||
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
|
||||
# Ordering of header names is not important, but the order of values for a header
|
||||
# name is. To assert whether that order is kept, we transform both the request
|
||||
# headers and the channel message headers into a dictionary
|
||||
# {name: [value1, value2, ...]} and check if they're equal.
|
||||
transformed_scope_headers = collections.defaultdict(list)
|
||||
for name, value in scope["headers"]:
|
||||
transformed_scope_headers[name].append(value)
|
||||
transformed_request_headers = collections.defaultdict(list)
|
||||
for name, value in (headers or []):
|
||||
expected_name = name.lower().strip().encode("ascii")
|
||||
expected_value = value.strip().encode("ascii")
|
||||
transformed_request_headers[expected_name].append(expected_value)
|
||||
for name, value in transformed_request_headers.items():
|
||||
self.assertIn(name, transformed_scope_headers)
|
||||
self.assertEqual(value, transformed_scope_headers[name])
|
||||
# Root path
|
||||
self.assertIsInstance(scope.get("root_path", ""), str)
|
||||
# Client and server addresses
|
||||
client = scope.get("client")
|
||||
if client is not None:
|
||||
self.assert_valid_address_and_port(client)
|
||||
server = scope.get("server")
|
||||
if server is not None:
|
||||
self.assert_valid_address_and_port(server)
|
||||
|
||||
def assert_valid_http_request_message(
|
||||
self,
|
||||
message,
|
||||
body=None,
|
||||
):
|
||||
"""
|
||||
Asserts that a message is a valid http.request message
|
||||
"""
|
||||
# Check overall keys
|
||||
self.assert_key_sets(
|
||||
required_keys={"type"},
|
||||
optional_keys={"body", "more_content"},
|
||||
actual_keys=message.keys(),
|
||||
)
|
||||
# Check that it is the right type
|
||||
self.assertEqual(message["type"], "http.request")
|
||||
# If there's a body present, check its type
|
||||
self.assertIsInstance(message.get("body", b""), bytes)
|
||||
if body is not None:
|
||||
self.assertEqual(body, message.get("body", b""))
|
||||
|
||||
def test_minimal_request(self):
|
||||
"""
|
||||
Smallest viable example. Mostly verifies that our request building works.
|
||||
"""
|
||||
scope, messages = self.run_daphne_request("GET", "/")
|
||||
self.assert_valid_http_scope(scope, "GET", "/")
|
||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
|
||||
@given(
|
||||
request_path=http_strategies.http_path(),
|
||||
request_params=http_strategies.query_params()
|
||||
)
|
||||
@settings(max_examples=5, deadline=2000)
|
||||
def test_get_request(self, request_path, request_params):
|
||||
"""
|
||||
Tests a typical HTTP GET request, with a path and query parameters
|
||||
"""
|
||||
scope, messages = self.run_daphne_request("GET", request_path, params=request_params)
|
||||
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
|
||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
|
||||
@given(
|
||||
request_path=http_strategies.http_path(),
|
||||
request_body=http_strategies.http_body()
|
||||
)
|
||||
@settings(max_examples=5, deadline=2000)
|
||||
def test_post_request(self, request_path, request_body):
|
||||
"""
|
||||
Tests a typical HTTP POST request, with a path and body.
|
||||
"""
|
||||
scope, messages = self.run_daphne_request("POST", request_path, data=request_body)
|
||||
self.assert_valid_http_scope(scope, "POST", request_path)
|
||||
self.assert_valid_http_request_message(messages[0], body=request_body)
|
||||
|
||||
@given(request_headers=http_strategies.headers())
|
||||
@settings(max_examples=5, deadline=2000)
|
||||
def test_headers(self, request_headers):
|
||||
"""
|
||||
Tests that HTTP header fields are handled as specified
|
||||
"""
|
||||
request_path = "/te st-à/"
|
||||
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers)
|
||||
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers)
|
||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
|
||||
# @given(request_headers=http_strategies.headers())
|
||||
# def test_duplicate_headers(self, request_headers):
|
||||
# """
|
||||
# Tests that duplicate header values are preserved
|
||||
# """
|
||||
# assume(len(request_headers) >= 2)
|
||||
# # Set all header field names to the same value
|
||||
# header_name = request_headers[0][0]
|
||||
# duplicated_headers = [(header_name, header[1]) for header in request_headers]
|
||||
|
||||
# request_method, request_path = "OPTIONS", "/te st-à/"
|
||||
# message = message_for_request(request_method, request_path, headers=duplicated_headers)
|
||||
|
||||
# self.assert_valid_http_request_message(
|
||||
# message, request_method, request_path, request_headers=duplicated_headers)
|
||||
|
||||
# @given(
|
||||
# request_method=http_strategies.http_method(),
|
||||
# request_path=http_strategies.http_path(),
|
||||
# request_params=http_strategies.query_params(),
|
||||
# request_headers=http_strategies.headers(),
|
||||
# request_body=http_strategies.http_body(),
|
||||
# )
|
||||
# # This test is slow enough that on Travis, hypothesis sometimes complains.
|
||||
# @settings(suppress_health_check=[HealthCheck.too_slow])
|
||||
# def test_kitchen_sink(
|
||||
# self, request_method, request_path, request_params, request_headers, request_body):
|
||||
# """
|
||||
# Throw everything at channels that we dare. The idea is that if a combination
|
||||
# of method/path/headers/body would break the spec, hypothesis will eventually find it.
|
||||
# """
|
||||
# request_headers.append(content_length_header(request_body))
|
||||
# message = message_for_request(
|
||||
# request_method, request_path, request_params, request_headers, request_body)
|
||||
|
||||
# self.assert_valid_http_request_message(
|
||||
# message, request_method, request_path, request_params, request_headers, request_body)
|
||||
|
||||
# def test_headers_are_lowercased_and_stripped(self):
|
||||
# request_method, request_path = "GET", "/"
|
||||
# headers = [("MYCUSTOMHEADER", " foobar ")]
|
||||
# message = message_for_request(request_method, request_path, headers=headers)
|
||||
|
||||
# self.assert_valid_http_request_message(
|
||||
# message, request_method, request_path, request_headers=headers)
|
||||
# # Note that Daphne returns a list of tuples here, which is fine, because the spec
|
||||
# # asks to treat them interchangeably.
|
||||
# assert message["headers"] == [(b"mycustomheader", b"foobar")]
|
||||
|
||||
# @given(daphne_path=http_strategies.http_path())
|
||||
# def test_root_path_header(self, daphne_path):
|
||||
# """
|
||||
# Tests root_path handling.
|
||||
# """
|
||||
# request_method, request_path = "GET", "/"
|
||||
# # Daphne-Root-Path must be URL encoded when submitting as HTTP header field
|
||||
# headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))]
|
||||
# message = message_for_request(request_method, request_path, headers=headers)
|
||||
|
||||
# # Daphne-Root-Path is not included in the returned 'headers' section. So we expect
|
||||
# # empty headers.
|
||||
# expected_headers = []
|
||||
# self.assert_valid_http_request_message(
|
||||
# message, request_method, request_path, request_headers=expected_headers)
|
||||
# # And what we're looking for, root_path being set.
|
||||
# assert message["root_path"] == daphne_path
|
||||
|
||||
|
||||
# class TestProxyHandling(unittest.TestCase):
|
||||
# """
|
||||
# Tests that concern interaction of Daphne with proxies.
|
||||
|
||||
# They live in a separate test case, because they're not part of the spec.
|
||||
# """
|
||||
|
||||
# def setUp(self):
|
||||
# self.channel_layer = ChannelLayer()
|
||||
# self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
|
||||
# self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
|
||||
# self.tr = proto_helpers.StringTransport()
|
||||
# self.proto.makeConnection(self.tr)
|
||||
|
||||
# def test_x_forwarded_for_ignored(self):
|
||||
# self.proto.dataReceived(
|
||||
# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
|
||||
# b"Host: somewhere.com\r\n" +
|
||||
# b"X-Forwarded-For: 10.1.2.3\r\n" +
|
||||
# b"X-Forwarded-Port: 80\r\n" +
|
||||
# b"\r\n"
|
||||
# )
|
||||
# # Get the resulting message off of the channel layer
|
||||
# _, message = self.channel_layer.receive(["http.request"])
|
||||
# self.assertEqual(message["client"], ["192.168.1.1", 54321])
|
||||
|
||||
# def test_x_forwarded_for_parsed(self):
|
||||
# self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
|
||||
# self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
|
||||
# self.proto.dataReceived(
|
||||
# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
|
||||
# b"Host: somewhere.com\r\n" +
|
||||
# b"X-Forwarded-For: 10.1.2.3\r\n" +
|
||||
# b"X-Forwarded-Port: 80\r\n" +
|
||||
# b"\r\n"
|
||||
# )
|
||||
# # Get the resulting message off of the channel layer
|
||||
# _, message = self.channel_layer.receive(["http.request"])
|
||||
# self.assertEqual(message["client"], ["10.1.2.3", 80])
|
||||
|
||||
# def test_x_forwarded_for_port_missing(self):
|
||||
# self.factory.proxy_forwarded_address_header = "X-Forwarded-For"
|
||||
# self.factory.proxy_forwarded_port_header = "X-Forwarded-Port"
|
||||
# self.proto.dataReceived(
|
||||
# b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
|
||||
# b"Host: somewhere.com\r\n" +
|
||||
# b"X-Forwarded-For: 10.1.2.3\r\n" +
|
||||
# b"\r\n"
|
||||
# )
|
||||
# # Get the resulting message off of the channel layer
|
||||
# _, message = self.channel_layer.receive(["http.request"])
|
||||
# self.assertEqual(message["client"], ["10.1.2.3", 0])
|
|
@ -1,11 +1,9 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
from unittest import TestCase
|
||||
import six
|
||||
|
||||
from twisted.web.http_headers import Headers
|
||||
from unittest import TestCase
|
||||
|
||||
from ..utils import parse_x_forwarded_for
|
||||
from daphne.utils import parse_x_forwarded_for
|
||||
|
||||
|
||||
class TestXForwardedForHttpParsing(TestCase):
|
||||
|
@ -20,7 +18,7 @@ class TestXForwardedForHttpParsing(TestCase):
|
|||
})
|
||||
result = parse_x_forwarded_for(headers)
|
||||
self.assertEqual(result, ["10.1.2.3", 1234])
|
||||
self.assertIsInstance(result[0], six.text_type)
|
||||
self.assertIsInstance(result[0], str)
|
||||
|
||||
def test_address_only(self):
|
||||
headers = Headers({
|
||||
|
@ -94,7 +92,7 @@ class TestXForwardedForWsParsing(TestCase):
|
|||
["1043::a321:0001", 0]
|
||||
)
|
||||
|
||||
def test_multiple_proxys(self):
|
||||
def test_multiple_proxies(self):
|
||||
headers = {
|
||||
b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4",
|
||||
}
|
Loading…
Reference in New Issue
Block a user