diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index f0657fd..b7da1bf 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -9,7 +9,7 @@ from twisted.protocols.policies import ProtocolWrapper from twisted.web import http from zope.interface import implementer -from .utils import parse_x_forwarded_for +from .utils import HEADER_NAME_RE, parse_x_forwarded_for logger = logging.getLogger(__name__) @@ -69,6 +69,13 @@ class WebRequest(http.Request): def process(self): try: self.request_start = time.time() + + # Validate header names. + for name, _ in self.requestHeaders.getAllRawHeaders(): + if not HEADER_NAME_RE.fullmatch(name): + self.basic_error(400, b"Bad Request", "Invalid header name") + return + # Get upgrade header upgrade_header = None if self.requestHeaders.hasHeader(b"Upgrade"): diff --git a/daphne/utils.py b/daphne/utils.py index 81f1f9d..0699314 100644 --- a/daphne/utils.py +++ b/daphne/utils.py @@ -1,7 +1,12 @@ import importlib +import re from twisted.web.http_headers import Headers +# Header name regex as per h11. +# https://github.com/python-hyper/h11/blob/a2c68948accadc3876dffcf979d98002e4a4ed27/h11/_abnf.py#L10-L21 +HEADER_NAME_RE = re.compile(rb"[-!#$%&'*+.^_`|~0-9a-zA-Z]+") + def import_by_path(path): """ diff --git a/tests/test_http_request.py b/tests/test_http_request.py index 52f6dd1..dca5bd7 100644 --- a/tests/test_http_request.py +++ b/tests/test_http_request.py @@ -310,3 +310,15 @@ class TestHTTPRequest(DaphneTestCase): b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n" ) self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request")) + + def test_invalid_header_name(self): + """ + Tests that requests with invalid header names fail. + """ + # Test cases follow those used by h11 + # https://github.com/python-hyper/h11/blob/a2c68948accadc3876dffcf979d98002e4a4ed27/h11/tests/test_headers.py#L24-L35 + for header_name in [b"foo bar", b"foo\x00bar", b"foo\xffbar", b"foo\x01bar"]: + response = self.run_daphne_raw( + f"GET / HTTP/1.0\r\n{header_name}: baz\r\n\r\n".encode("ascii") + ) + self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))