mirror of
				https://github.com/django/daphne.git
				synced 2025-10-25 04:51:05 +03:00 
			
		
		
		
	* Lint with pre-commit * Move existing tox qa hooks into pre-commit. * Set up GitHub Action based on https://github.com/pre-commit/action/ (we could also use https://pre-commit.ci ). * Add `pyupgrade` to drop old Python syntax. * Add `flake8-bugbear` plugin to prevent flake8 errors. * Drop custom GHA
		
			
				
	
	
		
			285 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			285 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import socket
 | |
| import struct
 | |
| import time
 | |
| import unittest
 | |
| from http.client import HTTPConnection
 | |
| from urllib import parse
 | |
| 
 | |
| from daphne.testing import DaphneTestingInstance, TestApplication
 | |
| 
 | |
| 
 | |
class DaphneTestCase(unittest.TestCase):
    """
    Base class for Daphne integration test cases.

    Boots up a copy of Daphne on a test port and sends it a request, and
    retrieves the response. Uses a custom ASGI application and temporary files
    to store/retrieve the request/response messages.
    """

    ### Plain HTTP helpers

    def run_daphne_http(
        self,
        method,
        path,
        params,
        body,
        responses,
        headers=None,
        timeout=1,
        xff=False,
        request_buffer_size=None,
    ):
        """
        Runs Daphne, sends it a single HTTP request and returns the outcome.

        ``params`` (if truthy) is url-encoded onto ``path`` as a query string,
        ``headers`` is an iterable of (name, value) pairs that are sent
        verbatim (duplicates allowed), and ``responses`` is the list of
        messages the test application will send back.

        Returns whatever ``test_app.get_received()`` yields with the
        ``HTTPResponse`` appended, i.e. (scope, messages, response).

        Raises RuntimeError if no response arrives within ``timeout`` seconds.
        """
        with DaphneTestingInstance(
            xff=xff, request_buffer_size=request_buffer_size
        ) as test_app:
            # Add the response messages
            test_app.add_send_messages(responses)
            # Send it the request. We have to do this the long way to allow
            # duplicate headers.
            conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
            if params:
                path += "?" + parse.urlencode(params, doseq=True)
            conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True)
            # Manually send over headers
            if headers:
                for header_name, header_value in headers:
                    conn.putheader(header_name, header_value)
            # Send body if provided.
            if body:
                conn.putheader("Content-Length", str(len(body)))
                conn.endheaders(message_body=body)
            else:
                conn.endheaders()
            try:
                response = conn.getresponse()
            except socket.timeout as exc:
                # See if they left an exception for us to load
                test_app.get_received()
                raise RuntimeError(
                    "Daphne timed out handling request, no exception found."
                ) from exc
            # Return scope, messages, response. The connection is left open on
            # purpose: the caller still reads from the response object.
            return test_app.get_received() + (response,)

    def run_daphne_raw(self, data, *, responses=None, timeout=1):
        """
        Runs Daphne and sends it the given raw bytestring over a socket.
        Accepts list of response messages the application will reply with.
        Returns what Daphne sends back (a single ``recv`` of up to 1000000
        bytes).

        Raises RuntimeError if nothing is received within ``timeout`` seconds.
        """
        assert isinstance(data, bytes)
        with DaphneTestingInstance() as test_app:
            if responses is not None:
                test_app.add_send_messages(responses)
            # The context manager guarantees the client socket is closed even
            # when the read times out.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.settimeout(timeout)
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.connect((test_app.host, test_app.port))
                # sendall() keeps writing until the whole payload is out;
                # send() may stop after a partial write.
                s.sendall(data)
                try:
                    return s.recv(1000000)
                except socket.timeout as exc:
                    raise RuntimeError(
                        "Daphne timed out handling raw request, no exception found."
                    ) from exc

    def run_daphne_request(
        self,
        method,
        path,
        params=None,
        body=None,
        headers=None,
        xff=False,
        request_buffer_size=None,
    ):
        """
        Convenience method for just testing request handling.

        The application always answers with a fixed 200 / b"OK" response,
        which is discarded.

        Returns (scope, messages)
        """
        scope, messages, _ = self.run_daphne_http(
            method=method,
            path=path,
            params=params,
            body=body,
            headers=headers,
            xff=xff,
            request_buffer_size=request_buffer_size,
            responses=[
                {"type": "http.response.start", "status": 200},
                {"type": "http.response.body", "body": b"OK"},
            ],
        )
        return scope, messages

    def run_daphne_response(self, response_messages):
        """
        Convenience method for just testing response handling.

        Issues a plain ``GET /`` and has the application reply with
        ``response_messages``.

        Returns the HTTP response object.
        """
        _, _, response = self.run_daphne_http(
            method="GET", path="/", params={}, body=b"", responses=response_messages
        )
        return response

    ### WebSocket helpers

    def websocket_handshake(
        self,
        test_app,
        path="/",
        params=None,
        headers=None,
        subprotocols=None,
        timeout=1,
    ):
        """
        Runs a WebSocket handshake negotiation and returns the raw socket
        object & the selected subprotocol.

        You'll need to inject an accept or reject message before this
        to let it complete.

        ``headers`` is an optional iterable of extra (name, value) pairs; it
        is not modified. Raises RuntimeError if the handshake times out or
        the response status is not 101.
        """
        # Send it the request. We have to do this the long way to allow
        # duplicate headers.
        conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
        if params:
            path += "?" + parse.urlencode(params, doseq=True)
        conn.putrequest("GET", path, skip_accept_encoding=True, skip_host=True)
        # Do WebSocket handshake headers + any other headers. Work on a copy
        # so the caller's list is never mutated.
        request_headers = list(headers) if headers else []
        request_headers.extend(
            [
                (b"Host", b"example.com"),
                (b"Upgrade", b"websocket"),
                (b"Connection", b"Upgrade"),
                (b"Sec-WebSocket-Key", b"x3JJHMbDL1EzLkh9GBhXDw=="),
                (b"Sec-WebSocket-Version", b"13"),
                (b"Origin", b"http://example.com"),
            ]
        )
        if subprotocols:
            request_headers.append(
                (b"Sec-WebSocket-Protocol", ", ".join(subprotocols))
            )
        for header_name, header_value in request_headers:
            conn.putheader(header_name, header_value)
        conn.endheaders()
        # Read out the response
        try:
            response = conn.getresponse()
        except socket.timeout as exc:
            # See if they left an exception for us to load
            test_app.get_received()
            raise RuntimeError(
                "Daphne timed out handling request, no exception found."
            ) from exc
        # Check we got a good response code
        if response.status != 101:
            raise RuntimeError("WebSocket upgrade did not result in status code 101")
        # Prepare headers for subprotocol searching
        response_headers = {n.lower(): v for n, v in response.getheaders()}
        response.read()
        assert not response.closed
        # Return the raw socket and any subprotocol
        return conn.sock, response_headers.get("sec-websocket-protocol")

    def websocket_send_frame(self, sock, value):
        """
        Sends a WebSocket text or binary frame.

        ``str`` values go out as a UTF-8 text frame, anything else is sent
        unchanged as a binary frame. Payloads of 126 bytes or more use the
        16-bit / 64-bit extended length fields.
        """
        # Header and text opcode
        if isinstance(value, str):
            frame = b"\x81"
            value = value.encode("utf8")
        else:
            frame = b"\x82"
        # Length plus masking signal bit. 126 and 127 are reserved markers
        # announcing a 16-bit or 64-bit length field respectively.
        length = len(value)
        if length < 126:
            frame += struct.pack("!B", length | 0b10000000)
        elif length < 1 << 16:
            frame += struct.pack("!BH", 126 | 0b10000000, length)
        else:
            frame += struct.pack("!BQ", 127 | 0b10000000, length)
        # Mask badly: an all-zero key leaves the payload bytes unchanged
        frame += b"\0\0\0\0"
        # Payload
        frame += value
        sock.sendall(frame)

    def receive_from_socket(self, sock, length, timeout=1):
        """
        Receives the given amount of bytes from the socket, or times out.

        Raises ValueError if ``timeout`` seconds pass before ``length`` bytes
        have arrived, or if the peer closes the connection first.
        """
        buf = b""
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while len(buf) < length:
            chunk = sock.recv(length - len(buf))
            if not chunk:
                # An empty read means the peer closed the stream; no more
                # data will ever arrive, so fail now instead of spinning
                # until the timeout.
                raise ValueError("Socket closed while reading from socket")
            buf += chunk
            if len(buf) < length:
                if time.monotonic() > deadline:
                    raise ValueError("Timed out reading from socket")
                time.sleep(0.001)
        return buf

    def websocket_receive_frame(self, sock):
        """
        Receives a single unmasked WebSocket text or binary frame.

        Text frames are returned as ``str`` (decoded as UTF-8), binary frames
        as ``bytes``. Extended (16-bit / 64-bit) payload lengths are handled;
        masked and fragmented frames are not.

        Raises ValueError for any other leading byte.
        """
        # Read header byte
        # TODO: Proper receive buffer handling
        opcode = self.receive_from_socket(sock, 1)
        if opcode not in (b"\x81", b"\x82"):
            raise ValueError("Unknown websocket opcode: %r" % opcode)
        # Read length; 126 and 127 announce an extended length field
        length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0]
        if length == 126:
            length = struct.unpack("!H", self.receive_from_socket(sock, 2))[0]
        elif length == 127:
            length = struct.unpack("!Q", self.receive_from_socket(sock, 8))[0]
        # Read payload
        payload = self.receive_from_socket(sock, length)
        if opcode == b"\x81":
            payload = payload.decode("utf8")
        return payload

    ### Assertions and test management

    def tearDown(self):
        """
        Ensures any storage files are cleared.
        """
        TestApplication.delete_setup()
        TestApplication.delete_result()
        super().tearDown()

    def assert_is_ip_address(self, address):
        """
        Tests whether a given address string is a valid IPv4 or IPv6 address.
        """
        # inet_aton() only understands IPv4, so try it first and fall back to
        # an explicit IPv6 parse.
        try:
            socket.inet_aton(address)
        except OSError:
            pass
        else:
            return
        if socket.has_ipv6:
            try:
                socket.inet_pton(socket.AF_INET6, address)
            except OSError:
                pass
            else:
                return
        self.fail("'%s' is not a valid IP address." % address)

    def assert_key_sets(self, required_keys, optional_keys, actual_keys):
        """
        Asserts that all required_keys are in actual_keys, and that there
        are no keys in actual_keys that aren't required or optional.

        All three arguments may be any iterable of hashable keys.
        """
        required = set(required_keys)
        optional = set(optional_keys)
        present_keys = set(actual_keys)
        # Make sure all required keys are present
        missing = required - present_keys
        self.assertFalse(missing, "Required keys are missing: %r" % (missing,))
        # Assert that no other keys are present
        self.assertEqual(set(), present_keys - required - optional)

    def assert_valid_path(self, path):
        """
        Checks the path is valid and already url-decoded.
        """
        self.assertIsInstance(path, str)
        # Assert that it's already url decoded
        self.assertEqual(path, parse.unquote(path))

    def assert_valid_address_and_port(self, host):
        """
        Asserts the value is a valid (host, port) tuple.
        """
        address, port = host
        self.assertIsInstance(address, str)
        self.assert_is_ip_address(address)
        self.assertIsInstance(port, int)
 |