Fixed #229: Allow bytes headers only

Previously Daphne was too lax and would happily accept strings too.
This commit is contained in:
Imblc 2018-09-28 23:45:03 +07:00 committed by Andrew Godwin
parent 3e4aab95e2
commit e93643ff5a
6 changed files with 95 additions and 32 deletions

View File

@ -211,9 +211,28 @@ class Server(object):
# Don't do anything if the connection is closed # Don't do anything if the connection is closed
if self.connections[protocol].get("disconnected", None): if self.connections[protocol].get("disconnected", None):
return return
self.check_headers_type(message)
# Let the protocol handle it # Let the protocol handle it
protocol.handle_reply(message) protocol.handle_reply(message)
@staticmethod
def check_headers_type(message):
if not message["type"] == "http.response.start":
return
for k, v in message.get("headers", []):
if not isinstance(k, bytes):
raise ValueError(
"Header name '{}' expected to be `bytes`, but got `{}`".format(
k, type(k)
)
)
if not isinstance(v, bytes):
raise ValueError(
"Header value '{}' expected to be `bytes`, but got `{}`".format(
v, type(v)
)
)
### Utility ### Utility
def application_checker(self): def application_checker(self):

View File

@ -37,12 +37,10 @@ class DaphneTestCase(unittest.TestCase):
if params: if params:
path += "?" + parse.urlencode(params, doseq=True) path += "?" + parse.urlencode(params, doseq=True)
conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True) conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True)
# Manually send over headers (encoding any non-safe values as best we can) # Manually send over headers
if headers: if headers:
for header_name, header_value in headers: for header_name, header_value in headers:
conn.putheader( conn.putheader(header_name, header_value)
header_name.encode("utf8"), header_value.encode("utf8")
)
# Send body if provided. # Send body if provided.
if body: if body:
conn.putheader("Content-Length", str(len(body))) conn.putheader("Content-Length", str(len(body)))
@ -140,19 +138,19 @@ class DaphneTestCase(unittest.TestCase):
headers = [] headers = []
headers.extend( headers.extend(
[ [
("Host", "example.com"), (b"Host", b"example.com"),
("Upgrade", "websocket"), (b"Upgrade", b"websocket"),
("Connection", "Upgrade"), (b"Connection", b"Upgrade"),
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), (b"Sec-WebSocket-Key", b"x3JJHMbDL1EzLkh9GBhXDw=="),
("Sec-WebSocket-Version", "13"), (b"Sec-WebSocket-Version", b"13"),
("Origin", "http://example.com"), (b"Origin", b"http://example.com"),
] ]
) )
if subprotocols: if subprotocols:
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols))) headers.append((b"Sec-WebSocket-Protocol", ", ".join(subprotocols)))
if headers: if headers:
for header_name, header_value in headers: for header_name, header_value in headers:
conn.putheader(header_name.encode("utf8"), header_value.encode("utf8")) conn.putheader(header_name, header_value)
conn.endheaders() conn.endheaders()
# Read out the response # Read out the response
try: try:

View File

@ -92,7 +92,7 @@ def header_name():
""" """
return strategies.text( return strategies.text(
alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30 alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30
) ).map(lambda s: s.encode("utf-8"))
def header_value(): def header_value():
@ -102,7 +102,8 @@ def header_value():
"For example, the Apache 2.3 server by default limits the size of each field to 8190 bytes" "For example, the Apache 2.3 server by default limits the size of each field to 8190 bytes"
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields https://en.wikipedia.org/wiki/List_of_HTTP_header_fields
""" """
return strategies.text( return (
strategies.text(
alphabet=string.ascii_letters alphabet=string.ascii_letters
+ string.digits + string.digits
+ string.punctuation.replace(",", "") + string.punctuation.replace(",", "")
@ -110,7 +111,10 @@ def header_value():
min_size=1, min_size=1,
average_size=40, average_size=40,
max_size=8190, max_size=8190,
).filter(lambda s: len(s.encode("utf8")) < 8190) )
.map(lambda s: s.encode("utf-8"))
.filter(lambda s: len(s) < 8190)
)
def headers(): def headers():

View File

@ -63,8 +63,8 @@ class TestHTTPRequest(DaphneTestCase):
transformed_scope_headers[name].append(value) transformed_scope_headers[name].append(value)
transformed_request_headers = collections.defaultdict(list) transformed_request_headers = collections.defaultdict(list)
for name, value in headers or []: for name, value in headers or []:
expected_name = name.lower().strip().encode("ascii") expected_name = name.lower().strip()
expected_value = value.strip().encode("ascii") expected_value = value.strip()
transformed_request_headers[expected_name].append(expected_value) transformed_request_headers[expected_name].append(expected_value)
for name, value in transformed_request_headers.items(): for name, value in transformed_request_headers.items():
self.assertIn(name, transformed_scope_headers) self.assertIn(name, transformed_scope_headers)
@ -209,7 +209,7 @@ class TestHTTPRequest(DaphneTestCase):
""" """
Make sure headers are normalized as the spec says they are. Make sure headers are normalized as the spec says they are.
""" """
headers = [("MYCUSTOMHEADER", " foobar ")] headers = [(b"MYCUSTOMHEADER", b" foobar ")]
scope, messages = self.run_daphne_request("GET", "/", headers=headers) scope, messages = self.run_daphne_request("GET", "/", headers=headers)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@ -237,7 +237,7 @@ class TestHTTPRequest(DaphneTestCase):
""" """
Make sure that, by default, X-Forwarded-For is ignored. Make sure that, by default, X-Forwarded-For is ignored.
""" """
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] headers = [[b"X-Forwarded-For", b"10.1.2.3"], [b"X-Forwarded-Port", b"80"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers) scope, messages = self.run_daphne_request("GET", "/", headers=headers)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@ -248,7 +248,7 @@ class TestHTTPRequest(DaphneTestCase):
""" """
When X-Forwarded-For is enabled, make sure it is respected. When X-Forwarded-For is enabled, make sure it is respected.
""" """
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] headers = [[b"X-Forwarded-For", b"10.1.2.3"], [b"X-Forwarded-Port", b"80"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@ -260,7 +260,7 @@ class TestHTTPRequest(DaphneTestCase):
When X-Forwarded-For is enabled but only the host is passed, make sure When X-Forwarded-For is enabled but only the host is passed, make sure
that at least makes it through. that at least makes it through.
""" """
headers = [["X-Forwarded-For", "10.1.2.3"]] headers = [[b"X-Forwarded-For", b"10.1.2.3"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")

View File

@ -19,10 +19,16 @@ class TestHTTPResponse(DaphneTestCase):
[ [
(name.lower(), value.strip()) (name.lower(), value.strip())
for name, value in headers for name, value in headers
if name.lower() != "transfer-encoding" if name.lower() != b"transfer-encoding"
] ]
) )
def encode_headers(self, headers):
def encode(s):
return s if isinstance(s, bytes) else s.encode("utf-8")
return [[encode(k), encode(v)] for k, v in headers]
def test_minimal_response(self): def test_minimal_response(self):
""" """
Smallest viable example. Mostly verifies that our response building works. Smallest viable example. Mostly verifies that our response building works.
@ -124,6 +130,42 @@ class TestHTTPResponse(DaphneTestCase):
) )
# Check headers in a sensible way. Ignore transfer-encoding. # Check headers in a sensible way. Ignore transfer-encoding.
self.assertEqual( self.assertEqual(
self.normalize_headers(response.getheaders()), self.normalize_headers(self.encode_headers(response.getheaders())),
self.normalize_headers(headers), self.normalize_headers(headers),
) )
def test_headers_type(self):
"""
Headers should be `bytes`
"""
with self.assertRaises(ValueError) as context:
self.run_daphne_response(
[
{
"type": "http.response.start",
"status": 200,
"headers": [["foo", b"bar"]],
},
{"type": "http.response.body", "body": b""},
]
)
self.assertEqual(
str(context.exception),
"Header name 'foo' expected to be `bytes`, but got `<class 'str'>`",
)
with self.assertRaises(ValueError) as context:
self.run_daphne_response(
[
{
"type": "http.response.start",
"status": 200,
"headers": [[b"foo", True]],
},
{"type": "http.response.body", "body": b""},
]
)
self.assertEqual(
str(context.exception),
"Header value 'True' expected to be `bytes`, but got `<class 'bool'>`",
)

View File

@ -56,8 +56,8 @@ class TestWebsocket(DaphneTestCase):
transformed_scope_headers[name].append(bit.strip()) transformed_scope_headers[name].append(bit.strip())
transformed_request_headers = collections.defaultdict(list) transformed_request_headers = collections.defaultdict(list)
for name, value in headers or []: for name, value in headers or []:
expected_name = name.lower().strip().encode("ascii") expected_name = name.lower().strip()
expected_value = value.strip().encode("ascii") expected_value = value.strip()
# Make sure to split out any headers collapsed with commas # Make sure to split out any headers collapsed with commas
transformed_request_headers.setdefault(expected_name, []) transformed_request_headers.setdefault(expected_name, [])
for bit in expected_value.split(b","): for bit in expected_value.split(b","):