diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 5e2681e..d44e9f2 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -308,33 +308,25 @@ class HTTPFactory(http.HTTPFactory): # Switch depending on current socket state protocol = self.reply_protocols[channel] # See if the message is valid - non_accept_keys = set(message.keys()) - {"accept"} - non_send_keys = set(message.keys()) - {"bytes", "text", "close"} - if non_accept_keys and non_send_keys: + unknown_keys = set(message.keys()) - {"bytes", "text", "close", "accept"} + if unknown_keys: raise ValueError( "Got invalid WebSocket reply message on %s - " - "contains unknown keys %s (looking for either {'accept'} or {'text', 'bytes', 'close'})" % ( + "contains unknown keys %s (looking for either {'accept', 'text', 'bytes', 'close'})" % ( channel, unknown_message_keys, ) ) - if "accept" in message: - if protocol.state != protocol.STATE_CONNECTING: - raise ValueError( - "Got invalid WebSocket connection reply message on %s - websocket is not in handshake phase" % ( - channel, - ) - ) - if message['accept']: - protocol.serverAccept() - else: + if message.get("accept", None) and protocol.state == protocol.STATE_CONNECTING: + protocol.serverAccept() + if message.get("bytes", None): + protocol.serverSend(message["bytes"], True) + if message.get("text", None): + protocol.serverSend(message["text"], False) + if message.get("close", False): + if protocol.state == protocol.STATE_CONNECTING: protocol.serverReject() - else: - if message.get("bytes", None): - protocol.serverSend(message["bytes"], True) - if message.get("text", None): - protocol.serverSend(message["text"], False) - if message.get("close", False): + else: protocol.serverClose() else: raise ValueError("Cannot dispatch message on channel %r" % channel) diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index db3e675..60f7ace 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -7,7 +7,7 @@ import traceback from six.moves.urllib_parse import unquote, urlencode from twisted.internet import defer -from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory +from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory, ConnectionDeny logger = logging.getLogger(__name__)