diff --git a/daphne/cli.py b/daphne/cli.py index 71faf0c..fd992bd 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -42,7 +42,7 @@ class CommandLineInterface(object): '--websocket_timeout', type=int, help='Maximum time to allow a websocket to be connected. -1 for infinite.', - default=None, + default=86400, ) self.parser.add_argument( '--websocket_connect_timeout', diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 6b4ece6..933f78b 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -30,8 +30,6 @@ class WebRequest(http.Request): GET and POST out. """ - application_type = "http" - error_template = """ @@ -138,8 +136,6 @@ class WebRequest(http.Request): # Boring old HTTP. else: - # Create application to handle this connection - self.application_queue = self.server.create_application(self) # Sanitize and decode headers, potentially extracting root path self.clean_headers = [] self.root_path = self.server.root_path @@ -154,21 +150,25 @@ class WebRequest(http.Request): self.clean_headers.append((name.lower(), value)) logger.debug("HTTP %s request for %s", self.method, self.client_addr) self.content.seek(0, 0) + # Work out the application scope and create application + self.application_queue = self.server.create_application(self, { + "type": "http", + # TODO: Correctly say if it's 1.1 or 1.0 + "http_version": self.clientproto.split(b"/")[-1].decode("ascii"), + "method": self.method.decode("ascii"), + "path": self.unquote(self.path), + "root_path": self.root_path, + "scheme": "https" if self.isSecure() else "http", + "query_string": self.query_string, + "headers": self.clean_headers, + "client": self.client_addr, + "server": self.server_addr, + }) # Run application against request self.application_queue.put_nowait( { "type": "http.request", - # TODO: Correctly say if it's 1.1 or 1.0 - "http_version": self.clientproto.split(b"/")[-1].decode("ascii"), - "method": self.method.decode("ascii"), - "path": self.unquote(self.path), - "root_path": self.root_path, - "scheme": "https" if self.isSecure() else "http", - "query_string": self.query_string, - "headers": self.clean_headers, "body": self.content.read(), - "client": self.client_addr, - "server": self.server_addr, }, ) except Exception: diff --git a/daphne/server.py b/daphne/server.py index 8f43ee6..b25019f 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -91,7 +91,7 @@ class Server(object): reactor.callLater(2, self.timeout_checker) for socket_description in self.endpoints: - logger.info("Listening on endpoint %s" % socket_description) + logger.info("Listening on endpoint %s", socket_description) ep = serverFromString(reactor, str(socket_description)) self.listeners.append(ep.listen(self.http_factory)) @@ -123,25 +123,23 @@ class Server(object): ### Internal event/message handling - def create_application(self, protocol): + def create_application(self, protocol, scope): """ Creates a new application instance that fronts a Protocol instance for one of our supported protocols. Pass it the protocol, and it will work out the type, supply appropriate callables, and return you the application's input queue """ - # Make sure the protocol defines a application type - assert protocol.application_type is not None # Make sure the protocol has not had another application made for it assert protocol not in self.application_instances # Make an instance of the application input_queue = asyncio.Queue() - application_instance = asyncio.ensure_future(self.application( - type=protocol.application_type, - next=input_queue.get, - reply=lambda message: self.handle_reply(protocol, message), + application_instance = self.application(scope=scope) + # Run it, and stash the future for later checking + self.application_instances[protocol] = asyncio.ensure_future(application_instance( + receive=input_queue.get, + send=lambda message: self.handle_reply(protocol, message), ), loop=asyncio.get_event_loop()) - self.application_instances[protocol] = application_instance return input_queue async def handle_reply(self, protocol, message): @@ -188,7 +186,7 @@ class Server(object): if not application_instance.done(): application_instance.cancel() wait_for.append(application_instance) - logging.info("Killed %i pending application instances" % len(wait_for)) + logging.info("Killed %i pending application instances", len(wait_for)) # Make Twisted wait until they're all dead wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for)) wait_deferred.addErrback(lambda x: None) @@ -199,7 +197,7 @@ class Server(object): Called periodically to enforce timeout rules on all connections. Also checks pings at the same time. """ - for protocol in self.protocols: + for protocol in list(self.protocols): protocol.check_timeouts() reactor.callLater(2, self.timeout_checker) diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 192fee2..cae0de4 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -33,8 +33,6 @@ class WebSocketProtocol(WebSocketServerProtocol): self.socket_opened = time.time() self.last_data = time.time() try: - # Make new application instance - self.application_queue = self.server.create_application(self) # Sanitize and decode headers self.clean_headers = [] for name, value in request.headers.items(): @@ -60,41 +58,31 @@ class WebSocketProtocol(WebSocketServerProtocol): self.server.proxy_forwarded_port_header, self.client_addr ) - - # Make initial request info dict from request (we only have it here) + # Decode websocket subprotocol options + subprotocols = [] + for header, value in self.clean_headers: + if header == b'sec-websocket-protocol': + subprotocols = [x.strip() for x in self.unquote(value).split(",")] + # Make new application instance with scope self.path = request.path.encode("ascii") - self.connect_message = { - "type": "websocket.connect", + self.application_queue = self.server.create_application(self, { + "type": "websocket", "path": self.unquote(self.path), "headers": self.clean_headers, "query_string": self._raw_query_string, # Passed by HTTP protocol "client": self.client_addr, "server": self.server_addr, + "subprotocols": subprotocols, "order": 0, - } + }) except: # Exceptions here are not displayed right, just 500. # Turn them into an ERROR log. logger.error(traceback.format_exc()) raise - ws_protocol = None - for header, value in self.clean_headers: - if header == b'sec-websocket-protocol': - protocols = [x.strip() for x in self.unquote(value).split(",")] - for protocol in protocols: - if protocol in self.server.websocket_protocols: - ws_protocol = protocol - break - - # Work out what subprotocol we will accept, if any - if ws_protocol and ws_protocol in self.server.websocket_protocols: - self.protocol_to_accept = ws_protocol - else: - self.protocol_to_accept = None - # Send over the connect message - self.application_queue.put_nowait(self.connect_message) + self.application_queue.put_nowait({"type": "websocket.connect"}) self.server.log_action("websocket", "connecting", { "path": self.request.path, "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, @@ -154,7 +142,7 @@ class WebSocketProtocol(WebSocketServerProtocol): if "type" not in message: raise ValueError("Message has no type defined") if message["type"] == "websocket.accept": - self.serverAccept() + self.serverAccept(message.get("subprotocol", None)) elif message["type"] == "websocket.close": if self.state == self.STATE_CONNECTING: self.serverReject() @@ -174,11 +162,23 @@ class WebSocketProtocol(WebSocketServerProtocol): if message.get("text", None): self.serverSend(message["text"], False) - def serverAccept(self): + def handle_exception(self, exception): + """ + Called by the server when our application tracebacks + """ + if hasattr(self, "handshake_deferred"): + # If the handshake is still ongoing, we need to emit a HTTP error + # code rather than a WebSocket one. + self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error")) + else: + self.sendCloseFrame(code=1011) + + def serverAccept(self, subprotocol=None): """ Called when we get a message saying to accept the connection. """ - self.handshake_deferred.callback(self.protocol_to_accept) + self.handshake_deferred.callback(subprotocol) + del self.handshake_deferred logger.debug("WebSocket %s accepted by application", self.client_addr) def serverReject(self): @@ -186,6 +186,7 @@ class WebSocketProtocol(WebSocketServerProtocol): Called when we get a message saying to reject the connection. """ self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) + del self.handshake_deferred self.server.discard_protocol(self) logger.debug("WebSocket %s rejected by application", self.client_addr) self.server.log_action("websocket", "rejected", {