diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index a94a466..0bf3c45 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -28,12 +28,13 @@ class WebSocketProtocol(WebSocketServerProtocol): self.last_data = time.time() try: # Sanitize and decode headers - clean_headers = {} + self.clean_headers = [] for name, value in request.headers.items(): + name = name.encode("ascii") # Prevent CVE-2015-0219 - if "_" in name: + if b"_" in name: continue - clean_headers[name.lower()] = value.encode("latin1") + self.clean_headers.append((name.lower(), value.encode("latin1"))) # Reconstruct query string # TODO: get autobahn to provide it raw query_string = urlencode(request.params, doseq=True).encode("ascii") @@ -52,7 +53,7 @@ class WebSocketProtocol(WebSocketServerProtocol): self.path = request.path.encode("ascii") self.request_info = { "path": self.unquote(self.path), - "headers": clean_headers, + "headers": self.clean_headers, "query_string": self.unquote(query_string), "client": self.client_addr, "server": self.server_addr, @@ -65,7 +66,10 @@ class WebSocketProtocol(WebSocketServerProtocol): logger.error(traceback.format_exc()) raise - ws_protocol = clean_headers.get('sec-websocket-protocol') + ws_protocol = None + for header, value in self.clean_headers: + if header == 'sec-websocket-protocol': + ws_protocol = value if ws_protocol and ws_protocol in self.factory.protocols: return ws_protocol