Make accept silently pass if already accepted

This commit is contained in:
Andrew Godwin 2016-10-05 16:00:11 -07:00
parent 8c637ff728
commit b537bed180
2 changed files with 13 additions and 21 deletions

View File

@ -308,33 +308,25 @@ class HTTPFactory(http.HTTPFactory):
# Switch depending on current socket state # Switch depending on current socket state
protocol = self.reply_protocols[channel] protocol = self.reply_protocols[channel]
# See if the message is valid # See if the message is valid
non_accept_keys = set(message.keys()) - {"accept"} unknown_keys = set(message.keys()) - {"bytes", "text", "close", "accept"}
non_send_keys = set(message.keys()) - {"bytes", "text", "close"} if unknown_keys:
if non_accept_keys and non_send_keys:
raise ValueError( raise ValueError(
"Got invalid WebSocket reply message on %s - " "Got invalid WebSocket reply message on %s - "
"contains unknown keys %s (looking for either {'accept'} or {'text', 'bytes', 'close'})" % ( "contains unknown keys %s (looking for either {'accept', 'text', 'bytes', 'close'})" % (
channel, channel,
unknown_message_keys, unknown_message_keys,
) )
) )
if "accept" in message: if message.get("accept", None) and protocol.state == protocol.STATE_CONNECTING:
if protocol.state != protocol.STATE_CONNECTING: protocol.serverAccept()
raise ValueError( if message.get("bytes", None):
"Got invalid WebSocket connection reply message on %s - websocket is not in handshake phase" % ( protocol.serverSend(message["bytes"], True)
channel, if message.get("text", None):
) protocol.serverSend(message["text"], False)
) if message.get("close", False):
if message['accept']: if protocol.state == protocol.STATE_CONNECTING:
protocol.serverAccept()
else:
protocol.serverReject() protocol.serverReject()
else: else:
if message.get("bytes", None):
protocol.serverSend(message["bytes"], True)
if message.get("text", None):
protocol.serverSend(message["text"], False)
if message.get("close", False):
protocol.serverClose() protocol.serverClose()
else: else:
raise ValueError("Cannot dispatch message on channel %r" % channel) raise ValueError("Cannot dispatch message on channel %r" % channel)

View File

@ -7,7 +7,7 @@ import traceback
from six.moves.urllib_parse import unquote, urlencode from six.moves.urllib_parse import unquote, urlencode
from twisted.internet import defer from twisted.internet import defer
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory, ConnectionDeny
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)