diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index bb67d8f..215011d 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -17,6 +17,9 @@ class WebSocketProtocol(WebSocketServerProtocol): the websocket channels. """ + # If we should send no more messages (e.g. we error-closed the socket) + muted = False + def set_main_factory(self, main_factory): self.main_factory = main_factory self.channel_layer = self.main_factory.channel_layer @@ -94,14 +97,21 @@ class WebSocketProtocol(WebSocketServerProtocol): try: self.channel_layer.send("websocket.connect", self.request_info) except self.channel_layer.ChannelFull: - # We don't drop the connection here as you don't _have_ to consume websocket.connect - pass - self.factory.log_action("websocket", "connected", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + # You have to consume websocket.connect according to the spec, + # so drop the connection. + self.muted = True + logger.warn("WebSocket force closed for %s due to backpressure", self.reply_channel) + self.sendClose() + else: + self.factory.log_action("websocket", "connected", { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, + }) def onMessage(self, payload, isBinary): + # If we're muted, do nothing. + if self.muted: + return logger.debug("WebSocket incoming packet on %s", self.reply_channel) self.packets_received += 1 self.last_data = time.time() @@ -147,11 +157,12 @@ class WebSocketProtocol(WebSocketServerProtocol): logger.debug("WebSocket closed for %s", self.reply_channel) del self.factory.reply_protocols[self.reply_channel] try: - self.channel_layer.send("websocket.disconnect", { - "reply_channel": self.reply_channel, - "path": self.unquote(self.path), - "order": self.packets_received + 1, - }) + if not self.muted: + self.channel_layer.send("websocket.disconnect", { + "reply_channel": self.reply_channel, + "path": self.unquote(self.path), + "order": self.packets_received + 1, + }) except self.channel_layer.ChannelFull: pass self.factory.log_action("websocket", "disconnected", {