mirror of
https://github.com/django/daphne.git
synced 2024-11-21 15:36:33 +03:00
parent
9a282dd627
commit
ef24796243
|
@ -9,7 +9,7 @@ from twisted.protocols.policies import ProtocolWrapper
|
||||||
from twisted.web import http
|
from twisted.web import http
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from .utils import parse_x_forwarded_for
|
from .utils import HEADER_NAME_RE, parse_x_forwarded_for
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -69,6 +69,13 @@ class WebRequest(http.Request):
|
||||||
def process(self):
|
def process(self):
|
||||||
try:
|
try:
|
||||||
self.request_start = time.time()
|
self.request_start = time.time()
|
||||||
|
|
||||||
|
# Validate header names.
|
||||||
|
for name, _ in self.requestHeaders.getAllRawHeaders():
|
||||||
|
if not HEADER_NAME_RE.fullmatch(name):
|
||||||
|
self.basic_error(400, b"Bad Request", "Invalid header name")
|
||||||
|
return
|
||||||
|
|
||||||
# Get upgrade header
|
# Get upgrade header
|
||||||
upgrade_header = None
|
upgrade_header = None
|
||||||
if self.requestHeaders.hasHeader(b"Upgrade"):
|
if self.requestHeaders.hasHeader(b"Upgrade"):
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
import importlib
|
import importlib
|
||||||
|
import re
|
||||||
|
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
|
# Header name regex as per h11.
|
||||||
|
# https://github.com/python-hyper/h11/blob/a2c68948accadc3876dffcf979d98002e4a4ed27/h11/_abnf.py#L10-L21
|
||||||
|
HEADER_NAME_RE = re.compile(rb"[-!#$%&'*+.^_`|~0-9a-zA-Z]+")
|
||||||
|
|
||||||
|
|
||||||
def import_by_path(path):
|
def import_by_path(path):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -310,3 +310,15 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
|
b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
|
||||||
)
|
)
|
||||||
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
|
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
|
||||||
|
|
||||||
|
def test_invalid_header_name(self):
|
||||||
|
"""
|
||||||
|
Tests that requests with invalid header names fail.
|
||||||
|
"""
|
||||||
|
# Test cases follow those used by h11
|
||||||
|
# https://github.com/python-hyper/h11/blob/a2c68948accadc3876dffcf979d98002e4a4ed27/h11/tests/test_headers.py#L24-L35
|
||||||
|
for header_name in [b"foo bar", b"foo\x00bar", b"foo\xffbar", b"foo\x01bar"]:
|
||||||
|
response = self.run_daphne_raw(
|
||||||
|
f"GET / HTTP/1.0\r\n{header_name}: baz\r\n\r\n".encode("ascii")
|
||||||
|
)
|
||||||
|
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user