mirror of
https://github.com/django/daphne.git
synced 2024-11-25 09:13:44 +03:00
e1b77e930b
The entire body was previously read in memory which would lead the server to be killed by the scheduler. This change allows 8Kb chunks to be read until the entire body is consummed. Co-authored-by: Samori Gorse <samori@codeinstyle.io>
285 lines
9.7 KiB
Python
285 lines
9.7 KiB
Python
import socket
|
|
import struct
|
|
import time
|
|
import unittest
|
|
from http.client import HTTPConnection
|
|
from urllib import parse
|
|
|
|
from daphne.testing import DaphneTestingInstance, TestApplication
|
|
|
|
|
|
class DaphneTestCase(unittest.TestCase):
|
|
"""
|
|
Base class for Daphne integration test cases.
|
|
|
|
Boots up a copy of Daphne on a test port and sends it a request, and
|
|
retrieves the response. Uses a custom ASGI application and temporary files
|
|
to store/retrieve the request/response messages.
|
|
"""
|
|
|
|
### Plain HTTP helpers
|
|
|
|
def run_daphne_http(
|
|
self,
|
|
method,
|
|
path,
|
|
params,
|
|
body,
|
|
responses,
|
|
headers=None,
|
|
timeout=1,
|
|
xff=False,
|
|
request_buffer_size=None,
|
|
):
|
|
"""
|
|
Runs Daphne with the given request callback (given the base URL)
|
|
and response messages.
|
|
"""
|
|
with DaphneTestingInstance(
|
|
xff=xff, request_buffer_size=request_buffer_size
|
|
) as test_app:
|
|
# Add the response messages
|
|
test_app.add_send_messages(responses)
|
|
# Send it the request. We have to do this the long way to allow
|
|
# duplicate headers.
|
|
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
|
|
if params:
|
|
path += "?" + parse.urlencode(params, doseq=True)
|
|
conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True)
|
|
# Manually send over headers
|
|
if headers:
|
|
for header_name, header_value in headers:
|
|
conn.putheader(header_name, header_value)
|
|
# Send body if provided.
|
|
if body:
|
|
conn.putheader("Content-Length", str(len(body)))
|
|
conn.endheaders(message_body=body)
|
|
else:
|
|
conn.endheaders()
|
|
try:
|
|
response = conn.getresponse()
|
|
except socket.timeout:
|
|
# See if they left an exception for us to load
|
|
test_app.get_received()
|
|
raise RuntimeError(
|
|
"Daphne timed out handling request, no exception found."
|
|
)
|
|
# Return scope, messages, response
|
|
return test_app.get_received() + (response,)
|
|
|
|
def run_daphne_raw(self, data, *, responses=None, timeout=1):
|
|
"""
|
|
Runs Daphne and sends it the given raw bytestring over a socket.
|
|
Accepts list of response messages the application will reply with.
|
|
Returns what Daphne sends back.
|
|
"""
|
|
assert isinstance(data, bytes)
|
|
with DaphneTestingInstance() as test_app:
|
|
if responses is not None:
|
|
test_app.add_send_messages(responses)
|
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
s.settimeout(timeout)
|
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
s.connect((test_app.host, test_app.port))
|
|
s.send(data)
|
|
try:
|
|
return s.recv(1000000)
|
|
except socket.timeout:
|
|
raise RuntimeError(
|
|
"Daphne timed out handling raw request, no exception found."
|
|
)
|
|
|
|
def run_daphne_request(
|
|
self,
|
|
method,
|
|
path,
|
|
params=None,
|
|
body=None,
|
|
headers=None,
|
|
xff=False,
|
|
request_buffer_size=None,
|
|
):
|
|
"""
|
|
Convenience method for just testing request handling.
|
|
Returns (scope, messages)
|
|
"""
|
|
scope, messages, _ = self.run_daphne_http(
|
|
method=method,
|
|
path=path,
|
|
params=params,
|
|
body=body,
|
|
headers=headers,
|
|
xff=xff,
|
|
request_buffer_size=request_buffer_size,
|
|
responses=[
|
|
{"type": "http.response.start", "status": 200},
|
|
{"type": "http.response.body", "body": b"OK"},
|
|
],
|
|
)
|
|
return scope, messages
|
|
|
|
def run_daphne_response(self, response_messages):
|
|
"""
|
|
Convenience method for just testing response handling.
|
|
Returns (scope, messages)
|
|
"""
|
|
_, _, response = self.run_daphne_http(
|
|
method="GET", path="/", params={}, body=b"", responses=response_messages
|
|
)
|
|
return response
|
|
|
|
### WebSocket helpers
|
|
|
|
def websocket_handshake(
|
|
self,
|
|
test_app,
|
|
path="/",
|
|
params=None,
|
|
headers=None,
|
|
subprotocols=None,
|
|
timeout=1,
|
|
):
|
|
"""
|
|
Runs a WebSocket handshake negotiation and returns the raw socket
|
|
object & the selected subprotocol.
|
|
|
|
You'll need to inject an accept or reject message before this
|
|
to let it complete.
|
|
"""
|
|
# Send it the request. We have to do this the long way to allow
|
|
# duplicate headers.
|
|
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
|
|
if params:
|
|
path += "?" + parse.urlencode(params, doseq=True)
|
|
conn.putrequest("GET", path, skip_accept_encoding=True, skip_host=True)
|
|
# Do WebSocket handshake headers + any other headers
|
|
if headers is None:
|
|
headers = []
|
|
headers.extend(
|
|
[
|
|
(b"Host", b"example.com"),
|
|
(b"Upgrade", b"websocket"),
|
|
(b"Connection", b"Upgrade"),
|
|
(b"Sec-WebSocket-Key", b"x3JJHMbDL1EzLkh9GBhXDw=="),
|
|
(b"Sec-WebSocket-Version", b"13"),
|
|
(b"Origin", b"http://example.com"),
|
|
]
|
|
)
|
|
if subprotocols:
|
|
headers.append((b"Sec-WebSocket-Protocol", ", ".join(subprotocols)))
|
|
if headers:
|
|
for header_name, header_value in headers:
|
|
conn.putheader(header_name, header_value)
|
|
conn.endheaders()
|
|
# Read out the response
|
|
try:
|
|
response = conn.getresponse()
|
|
except socket.timeout:
|
|
# See if they left an exception for us to load
|
|
test_app.get_received()
|
|
raise RuntimeError("Daphne timed out handling request, no exception found.")
|
|
# Check we got a good response code
|
|
if response.status != 101:
|
|
raise RuntimeError("WebSocket upgrade did not result in status code 101")
|
|
# Prepare headers for subprotocol searching
|
|
response_headers = dict((n.lower(), v) for n, v in response.getheaders())
|
|
response.read()
|
|
assert not response.closed
|
|
# Return the raw socket and any subprotocol
|
|
return conn.sock, response_headers.get("sec-websocket-protocol", None)
|
|
|
|
def websocket_send_frame(self, sock, value):
|
|
"""
|
|
Sends a WebSocket text or binary frame. Cannot handle long frames.
|
|
"""
|
|
# Header and text opcode
|
|
if isinstance(value, str):
|
|
frame = b"\x81"
|
|
value = value.encode("utf8")
|
|
else:
|
|
frame = b"\x82"
|
|
# Length plus masking signal bit
|
|
frame += struct.pack("!B", len(value) | 0b10000000)
|
|
# Mask badly
|
|
frame += b"\0\0\0\0"
|
|
# Payload
|
|
frame += value
|
|
sock.sendall(frame)
|
|
|
|
def receive_from_socket(self, sock, length, timeout=1):
|
|
"""
|
|
Receives the given amount of bytes from the socket, or times out.
|
|
"""
|
|
buf = b""
|
|
started = time.time()
|
|
while len(buf) < length:
|
|
buf += sock.recv(length - len(buf))
|
|
time.sleep(0.001)
|
|
if time.time() - started > timeout:
|
|
raise ValueError("Timed out reading from socket")
|
|
return buf
|
|
|
|
def websocket_receive_frame(self, sock):
|
|
"""
|
|
Receives a WebSocket frame. Cannot handle long frames.
|
|
"""
|
|
# Read header byte
|
|
# TODO: Proper receive buffer handling
|
|
opcode = self.receive_from_socket(sock, 1)
|
|
if opcode in [b"\x81", b"\x82"]:
|
|
# Read length
|
|
length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0]
|
|
# Read payload
|
|
payload = self.receive_from_socket(sock, length)
|
|
if opcode == b"\x81":
|
|
payload = payload.decode("utf8")
|
|
return payload
|
|
else:
|
|
raise ValueError("Unknown websocket opcode: %r" % opcode)
|
|
|
|
### Assertions and test management
|
|
|
|
def tearDown(self):
|
|
"""
|
|
Ensures any storage files are cleared.
|
|
"""
|
|
TestApplication.delete_setup()
|
|
TestApplication.delete_result()
|
|
|
|
def assert_is_ip_address(self, address):
|
|
"""
|
|
Tests whether a given address string is a valid IPv4 or IPv6 address.
|
|
"""
|
|
try:
|
|
socket.inet_aton(address)
|
|
except socket.error:
|
|
self.fail("'%s' is not a valid IP address." % address)
|
|
|
|
def assert_key_sets(self, required_keys, optional_keys, actual_keys):
|
|
"""
|
|
Asserts that all required_keys are in actual_keys, and that there
|
|
are no keys in actual_keys that aren't required or optional.
|
|
"""
|
|
present_keys = set(actual_keys)
|
|
# Make sure all required keys are present
|
|
self.assertTrue(required_keys <= present_keys)
|
|
# Assert that no other keys are present
|
|
self.assertEqual(set(), present_keys - required_keys - optional_keys)
|
|
|
|
def assert_valid_path(self, path):
|
|
"""
|
|
Checks the path is valid and already url-decoded.
|
|
"""
|
|
self.assertIsInstance(path, str)
|
|
# Assert that it's already url decoded
|
|
self.assertEqual(path, parse.unquote(path))
|
|
|
|
def assert_valid_address_and_port(self, host):
|
|
"""
|
|
Asserts the value is a valid (host, port) tuple.
|
|
"""
|
|
address, port = host
|
|
self.assertIsInstance(address, str)
|
|
self.assert_is_ip_address(address)
|
|
self.assertIsInstance(port, int)
|