Use code when rejecting connection

Fixed #486
This commit is contained in:
Dave Johansen 2023-10-05 18:53:01 -06:00 committed by GitHub
parent 2d4dcbf149
commit 7234cb1638
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -185,7 +185,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.serverAccept(message.get("subprotocol", None)) self.serverAccept(message.get("subprotocol", None))
elif message["type"] == "websocket.close": elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING: if self.state == self.STATE_CONNECTING:
self.serverReject() self.serverReject(code=message.get("code", 403))
else: else:
self.serverClose(code=message.get("code", None)) self.serverClose(code=message.get("code", None))
elif message["type"] == "websocket.send": elif message["type"] == "websocket.send":
@ -222,12 +222,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
del self.handshake_deferred del self.handshake_deferred
logger.debug("WebSocket %s accepted by application", self.client_addr) logger.debug("WebSocket %s accepted by application", self.client_addr)
def serverReject(self): def serverReject(self, code):
""" """
Called when we get a message saying to reject the connection. Called when we get a message saying to reject the connection.
""" """
self.handshake_deferred.errback( self.handshake_deferred.errback(
ConnectionDeny(code=403, reason="Access denied") ConnectionDeny(code=code, reason="Connection closed")
) )
del self.handshake_deferred del self.handshake_deferred
self.server.protocol_disconnected(self) self.server.protocol_disconnected(self)
@ -284,7 +284,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
# If we're still connecting, deny the connection # If we're still connecting, deny the connection
if self.state == self.STATE_CONNECTING: if self.state == self.STATE_CONNECTING:
if self.duration() > self.server.websocket_connect_timeout: if self.duration() > self.server.websocket_connect_timeout:
self.serverReject() self.serverReject(408)
elif self.state == self.STATE_OPEN: elif self.state == self.STATE_OPEN:
if (time.time() - self.last_ping) > self.server.ping_interval: if (time.time() - self.last_ping) > self.server.ping_interval:
self._sendAutoPing() self._sendAutoPing()