diff --git a/daphne/server.py b/daphne/server.py index 133762d..3f27bf5 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -211,9 +211,28 @@ class Server(object): # Don't do anything if the connection is closed if self.connections[protocol].get("disconnected", None): return + self.check_headers_type(message) # Let the protocol handle it protocol.handle_reply(message) + @staticmethod + def check_headers_type(message): + if not message["type"] == "http.response.start": + return + for k, v in message.get("headers", []): + if not isinstance(k, bytes): + raise ValueError( + "Header name '{}' expected to be `bytes`, but got `{}`".format( + k, type(k) + ) + ) + if not isinstance(v, bytes): + raise ValueError( + "Header value '{}' expected to be `bytes`, but got `{}`".format( + v, type(v) + ) + ) + ### Utility def application_checker(self): diff --git a/tests/http_base.py b/tests/http_base.py index 866a066..e6fc92b 100644 --- a/tests/http_base.py +++ b/tests/http_base.py @@ -37,12 +37,10 @@ class DaphneTestCase(unittest.TestCase): if params: path += "?" + parse.urlencode(params, doseq=True) conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True) - # Manually send over headers (encoding any non-safe values as best we can) + # Manually send over headers if headers: for header_name, header_value in headers: - conn.putheader( - header_name.encode("utf8"), header_value.encode("utf8") - ) + conn.putheader(header_name, header_value) # Send body if provided. if body: conn.putheader("Content-Length", str(len(body))) @@ -140,19 +138,19 @@ class DaphneTestCase(unittest.TestCase): headers = [] headers.extend( [ - ("Host", "example.com"), - ("Upgrade", "websocket"), - ("Connection", "Upgrade"), - ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), - ("Sec-WebSocket-Version", "13"), - ("Origin", "http://example.com"), + (b"Host", b"example.com"), + (b"Upgrade", b"websocket"), + (b"Connection", b"Upgrade"), + (b"Sec-WebSocket-Key", b"x3JJHMbDL1EzLkh9GBhXDw=="), + (b"Sec-WebSocket-Version", b"13"), + (b"Origin", b"http://example.com"), ] ) if subprotocols: - headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols))) + headers.append((b"Sec-WebSocket-Protocol", ", ".join(subprotocols))) if headers: for header_name, header_value in headers: - conn.putheader(header_name.encode("utf8"), header_value.encode("utf8")) + conn.putheader(header_name, header_value) conn.endheaders() # Read out the response try: diff --git a/tests/http_strategies.py b/tests/http_strategies.py index d78ac10..e9d8736 100644 --- a/tests/http_strategies.py +++ b/tests/http_strategies.py @@ -92,7 +92,7 @@ def header_name(): """ return strategies.text( alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30 - ) + ).map(lambda s: s.encode("utf-8")) def header_value(): @@ -102,15 +102,19 @@ def header_value(): "For example, the Apache 2.3 server by default limits the size of each field to 8190 bytes" https://en.wikipedia.org/wiki/List_of_HTTP_header_fields """ - return strategies.text( - alphabet=string.ascii_letters - + string.digits - + string.punctuation.replace(",", "") - + " /t", - min_size=1, - average_size=40, - max_size=8190, - ).filter(lambda s: len(s.encode("utf8")) < 8190) + return ( + strategies.text( + alphabet=string.ascii_letters + + string.digits + + string.punctuation.replace(",", "") + + " /t", + min_size=1, + average_size=40, + max_size=8190, + ) + .map(lambda s: s.encode("utf-8")) + .filter(lambda s: len(s) < 8190) + ) def headers(): diff --git a/tests/test_http_request.py b/tests/test_http_request.py index e02b8b6..c1efd64 100644 --- a/tests/test_http_request.py +++ b/tests/test_http_request.py @@ -63,8 +63,8 @@ class TestHTTPRequest(DaphneTestCase): transformed_scope_headers[name].append(value) transformed_request_headers = collections.defaultdict(list) for name, value in headers or []: - expected_name = name.lower().strip().encode("ascii") - expected_value = value.strip().encode("ascii") + expected_name = name.lower().strip() + expected_value = value.strip() transformed_request_headers[expected_name].append(expected_value) for name, value in transformed_request_headers.items(): self.assertIn(name, transformed_scope_headers) @@ -209,7 +209,7 @@ class TestHTTPRequest(DaphneTestCase): """ Make sure headers are normalized as the spec says they are. """ - headers = [("MYCUSTOMHEADER", " foobar ")] + headers = [(b"MYCUSTOMHEADER", b" foobar ")] scope, messages = self.run_daphne_request("GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -237,7 +237,7 @@ class TestHTTPRequest(DaphneTestCase): """ Make sure that, by default, X-Forwarded-For is ignored. """ - headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] + headers = [[b"X-Forwarded-For", b"10.1.2.3"], [b"X-Forwarded-Port", b"80"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -248,7 +248,7 @@ class TestHTTPRequest(DaphneTestCase): """ When X-Forwarded-For is enabled, make sure it is respected. """ - headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] + headers = [[b"X-Forwarded-For", b"10.1.2.3"], [b"X-Forwarded-Port", b"80"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -260,7 +260,7 @@ class TestHTTPRequest(DaphneTestCase): When X-Forwarded-For is enabled but only the host is passed, make sure that at least makes it through. """ - headers = [["X-Forwarded-For", "10.1.2.3"]] + headers = [[b"X-Forwarded-For", b"10.1.2.3"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") diff --git a/tests/test_http_response.py b/tests/test_http_response.py index 3576697..afb8e39 100644 --- a/tests/test_http_response.py +++ b/tests/test_http_response.py @@ -19,10 +19,16 @@ class TestHTTPResponse(DaphneTestCase): [ (name.lower(), value.strip()) for name, value in headers - if name.lower() != "transfer-encoding" + if name.lower() != b"transfer-encoding" ] ) + def encode_headers(self, headers): + def encode(s): + return s if isinstance(s, bytes) else s.encode("utf-8") + + return [[encode(k), encode(v)] for k, v in headers] + def test_minimal_response(self): """ Smallest viable example. Mostly verifies that our response building works. @@ -124,6 +130,42 @@ class TestHTTPResponse(DaphneTestCase): ) # Check headers in a sensible way. Ignore transfer-encoding. self.assertEqual( - self.normalize_headers(response.getheaders()), + self.normalize_headers(self.encode_headers(response.getheaders())), self.normalize_headers(headers), ) + + def test_headers_type(self): + """ + Headers should be `bytes` + """ + with self.assertRaises(ValueError) as context: + self.run_daphne_response( + [ + { + "type": "http.response.start", + "status": 200, + "headers": [["foo", b"bar"]], + }, + {"type": "http.response.body", "body": b""}, + ] + ) + self.assertEqual( + str(context.exception), + "Header name 'foo' expected to be `bytes`, but got ``", + ) + + with self.assertRaises(ValueError) as context: + self.run_daphne_response( + [ + { + "type": "http.response.start", + "status": 200, + "headers": [[b"foo", True]], + }, + {"type": "http.response.body", "body": b""}, + ] + ) + self.assertEqual( + str(context.exception), + "Header value 'True' expected to be `bytes`, but got ``", + ) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 80ec21d..69a54f8 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -56,8 +56,8 @@ class TestWebsocket(DaphneTestCase): transformed_scope_headers[name].append(bit.strip()) transformed_request_headers = collections.defaultdict(list) for name, value in headers or []: - expected_name = name.lower().strip().encode("ascii") - expected_value = value.strip().encode("ascii") + expected_name = name.lower().strip() + expected_value = value.strip() # Make sure to split out any headers collapsed with commas transformed_request_headers.setdefault(expected_name, []) for bit in expected_value.split(b","):