diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index e466219..2f82250 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -227,11 +227,12 @@ class WebRequest(http.Request): """ Writes a received HTTP response back out to the transport. """ - if "status" in message: - if self._got_response_start: - raise ValueError("Got multiple Response messages for %s!" % self.reply_channel) + if not self._got_response_start: self._got_response_start = True - # Write code + if 'status' not in message: + raise ValueError("Specifying a status code is required for a Response message.") + + # Set HTTP status code self.setResponseCode(message['status']) # Write headers for header, value in message.get("headers", {}): @@ -240,9 +241,13 @@ class WebRequest(http.Request): header = header.encode("latin1") self.responseHeaders.addRawHeader(header, value) logger.debug("HTTP %s response started for %s", message['status'], self.reply_channel) + else: + if 'status' in message: + raise ValueError("Got multiple Response messages for %s!" % self.reply_channel) + # Write out body - if "content" in message: - http.Request.write(self, message['content']) + http.Request.write(self, message.get('content', b'')) + # End if there's no more content if not message.get("more_content", False): self.finish() diff --git a/daphne/tests/factories.py b/daphne/tests/factories.py index f972a88..fd9013a 100644 --- a/daphne/tests/factories.py +++ b/daphne/tests/factories.py @@ -14,7 +14,23 @@ def message_for_request(method, path, params=None, headers=None, body=None): that through daphne and returns the emitted channel message. """ request = _build_request(method, path, params, headers, body) - return _run_through_daphne(request, 'http.request') + message, factory, transport = _run_through_daphne(request, 'http.request') + return message + + +def response_for_message(message): + """ + Returns the raw HTTP response that Daphne constructs when sending a reply + to a HTTP request. + + The current approach actually first builds a HTTP request (similar to + message_for_request) because we need a valid reply channel. I'm sure + this can be streamlined, but it works for now. + """ + request = _build_request('GET', '/') + request_message, factory, transport = _run_through_daphne(request, 'http.request') + factory.dispatch_reply(request_message['reply_channel'], message) + return transport.value() def _build_request(method, path, params=None, headers=None, body=None): @@ -57,8 +73,8 @@ def _build_request(method, path, params=None, headers=None, body=None): quoted_path += b'?' + parse.urlencode(params) request = method.encode('ascii') + b' ' + quoted_path + b" HTTP/1.1\r\n" - for k, v in headers: - request += k.encode('ascii') + b': ' + v.encode('ascii') + b"\r\n" + for name, value in headers: + request += header_line(name, value) request += b'\r\n' @@ -68,6 +84,13 @@ def _build_request(method, path, params=None, headers=None, body=None): return request +def header_line(name, value): + """ + Given a header name and value, returns the line to use in a HTTP request or response. + """ + return name.encode('ascii') + b': ' + value.encode('ascii') + b"\r\n" + + def _run_through_daphne(request, channel_name): """ Returns Daphne's channel message for a given request. @@ -78,11 +101,11 @@ def _run_through_daphne(request, channel_name): channel_layer = ChannelLayer() factory = HTTPFactory(channel_layer) proto = factory.buildProtocol(('127.0.0.1', 0)) - tr = proto_helpers.StringTransport() - proto.makeConnection(tr) + transport = proto_helpers.StringTransport() + proto.makeConnection(transport) proto.dataReceived(request) _, message = channel_layer.receive([channel_name]) - return message + return message, factory, transport def content_length_header(body): diff --git a/daphne/tests/test_http_response.py b/daphne/tests/test_http_response.py index bd1005d..0212df6 100644 --- a/daphne/tests/test_http_response.py +++ b/daphne/tests/test_http_response.py @@ -7,9 +7,84 @@ from __future__ import unicode_literals from unittest import TestCase from asgiref.inmemory import ChannelLayer +from hypothesis import given from twisted.test import proto_helpers -from ..http_protocol import HTTPFactory +from daphne.http_protocol import HTTPFactory +from . import factories, http_strategies, testcases + + +class TestHTTPResponseSpec(testcases.ASGITestCase): + + def test_minimal_response(self): + """ + Smallest viable example. Mostly verifies that our response building works. + """ + message = {'status': 200} + response = factories.response_for_message(message) + self.assert_valid_http_response_message(message, response) + self.assertIn(b'200 OK', response) + # Assert that the response is the last of the chunks. + # N.b. at the time of writing, Daphne did not support multiple response chunks, + # but still sends with Transfer-Encoding: chunked if no Content-Length header + # is specified (and maybe even if specified). + self.assertTrue(response.endswith(b'0\r\n\r\n')) + + def test_status_code_required(self): + """ + Asserts that passing in the 'status' key is required. + + Previous versions of Daphne did not enforce this, so this test is here + to make sure it stays required. + """ + with self.assertRaises(ValueError): + factories.response_for_message({}) + + def test_status_code_is_transmitted(self): + """ + Tests that a custom status code is present in the response. + + We can't really use hypothesis to test all sorts of status codes, because a lot + of them have meaning that is respected by Twisted. E.g. setting 204 (No Content) + as a status code results in Twisted discarding the body. + """ + message = {'status': 201} # 'Created' + response = factories.response_for_message(message) + self.assert_valid_http_response_message(message, response) + self.assertIn(b'201 Created', response) + + @given(body=http_strategies.http_body()) + def test_body_is_transmitted(self, body): + message = {'status': 200, 'content': body.encode('ascii')} + response = factories.response_for_message(message) + self.assert_valid_http_response_message(message, response) + + @given(headers=http_strategies.headers()) + def test_headers(self, headers): + # The ASGI spec requires us to lowercase our header names + message = {'status': 200, 'headers': [(name.lower(), value) for name, value in headers]} + response = factories.response_for_message(message) + # The assert_ method does the heavy lifting of checking that headers are + # as expected. + self.assert_valid_http_response_message(message, response) + + @given( + headers=http_strategies.headers(), + body=http_strategies.http_body() + ) + def test_kitchen_sink(self, headers, body): + """ + This tests tries to let Hypothesis find combinations of variables that result + in breaking our assumptions. But responses are less exciting than responses, + so there's not a lot going on here. + """ + message = { + 'status': 202, # 'Accepted' + 'headers': [(name.lower(), value) for name, value in headers], + 'content': body.encode('ascii') + } + response = factories.response_for_message(message) + self.assert_valid_http_response_message(message, response) class TestHTTPResponse(TestCase): @@ -24,39 +99,6 @@ class TestHTTPResponse(TestCase): self.tr = proto_helpers.StringTransport() self.proto.makeConnection(self.tr) - def test_basic(self): - """ - Tests basic HTTP parsing - """ - # Send a simple request to the protocol - self.proto.dataReceived( - b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + - b"Host: somewhere.com\r\n" + - b"\r\n" - ) - # Get the resulting message off of the channel layer - _, message = self.channel_layer.receive(["http.request"]) - self.assertEqual(message['http_version'], "1.1") - self.assertEqual(message['method'], "GET") - self.assertEqual(message['scheme'], "http") - self.assertEqual(message['path'], "/te st-à/") - self.assertEqual(message['query_string'], b"foo=+bar") - self.assertEqual(message['headers'], [(b"host", b"somewhere.com")]) - self.assertFalse(message.get("body", None)) - self.assertTrue(message['reply_channel']) - # Send back an example response - self.factory.dispatch_reply( - message['reply_channel'], - { - "status": 201, - "status_text": b"Created", - "content": b"OH HAI", - "headers": [[b"X-Test", b"Boom!"]], - } - ) - # Make sure that comes back right on the protocol - self.assertEqual(self.tr.value(), b"HTTP/1.1 201 Created\r\nTransfer-Encoding: chunked\r\nX-Test: Boom!\r\n\r\n6\r\nOH HAI\r\n0\r\n\r\n") - def test_http_disconnect_sets_path_key(self): """ Tests http disconnect has the path key set, see https://channels.readthedocs.io/en/latest/asgi.html#disconnect diff --git a/daphne/tests/testcases.py b/daphne/tests/testcases.py index 6e2e7a8..df1750b 100644 --- a/daphne/tests/testcases.py +++ b/daphne/tests/testcases.py @@ -4,12 +4,13 @@ Contains a test case class to allow verifying ASGI messages from __future__ import unicode_literals from collections import defaultdict - import six -import socket from six.moves.urllib import parse +import socket import unittest +from . import factories + class ASGITestCase(unittest.TestCase): """ @@ -122,3 +123,27 @@ class ASGITestCase(unittest.TestCase): self.assertIsInstance(server_host, six.text_type) self.assert_is_ip_address(server_host) self.assertIsInstance(server_port, int) + + def assert_valid_http_response_message(self, message, response): + self.assertTrue(message) + self.assertTrue(response.startswith(b'HTTP')) + + status_code_bytes = six.text_type(message['status']).encode('ascii') + self.assertIn(status_code_bytes, response) + + if 'content' in message: + self.assertIn(message['content'], response) + + # Check that headers are in the given order. + # N.b. HTTP spec only enforces that the order of header values is kept, but + # the ASGI spec requires that order of all headers is kept. This code + # checks conformance with the stricter ASGI spec. + if 'headers' in message: + for name, value in message['headers']: + expected_header = factories.header_line(name, value) + # Daphne or Twisted turn our lower cased header names ('foo-bar') into title + # case ('Foo-Bar'). So technically we want to to match that the header name is + # present while ignoring casing, and want to ensure the value is present without + # altered casing. The approach below does this well enough. + self.assertIn(expected_header.lower(), response.lower()) + self.assertIn(value.encode('ascii'), response)