Make accept silently pass if already accepted

This commit is contained in:
Andrew Godwin 2016-10-05 16:00:11 -07:00
parent 8c637ff728
commit b537bed180
2 changed files with 13 additions and 21 deletions

View File

@ -308,33 +308,25 @@ class HTTPFactory(http.HTTPFactory):
# Switch depending on current socket state
protocol = self.reply_protocols[channel]
# See if the message is valid
non_accept_keys = set(message.keys()) - {"accept"}
non_send_keys = set(message.keys()) - {"bytes", "text", "close"}
if non_accept_keys and non_send_keys:
unknown_keys = set(message.keys()) - {"bytes", "text", "close", "accept"}
if unknown_keys:
raise ValueError(
"Got invalid WebSocket reply message on %s - "
"contains unknown keys %s (looking for either {'accept'} or {'text', 'bytes', 'close'})" % (
"contains unknown keys %s (looking for either {'accept', 'text', 'bytes', 'close'})" % (
channel,
unknown_message_keys,
)
)
if "accept" in message:
if protocol.state != protocol.STATE_CONNECTING:
raise ValueError(
"Got invalid WebSocket connection reply message on %s - websocket is not in handshake phase" % (
channel,
)
)
if message['accept']:
protocol.serverAccept()
else:
if message.get("accept", None) and protocol.state == protocol.STATE_CONNECTING:
protocol.serverAccept()
if message.get("bytes", None):
protocol.serverSend(message["bytes"], True)
if message.get("text", None):
protocol.serverSend(message["text"], False)
if message.get("close", False):
if protocol.state == protocol.STATE_CONNECTING:
protocol.serverReject()
else:
if message.get("bytes", None):
protocol.serverSend(message["bytes"], True)
if message.get("text", None):
protocol.serverSend(message["text"], False)
if message.get("close", False):
else:
protocol.serverClose()
else:
raise ValueError("Cannot dispatch message on channel %r" % channel)

View File

@ -7,7 +7,7 @@ import traceback
from six.moves.urllib_parse import unquote, urlencode
from twisted.internet import defer
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory, ConnectionDeny
logger = logging.getLogger(__name__)