mirror of
				https://github.com/django/daphne.git
				synced 2025-10-30 23:37:25 +03:00 
			
		
		
		
	Switch to new explicit WebSocket acceptance
This commit is contained in:
		
							parent
							
								
									1bdf2a7518
								
							
						
					
					
						commit
						685f3aed1e
					
				|  | @ -24,6 +24,18 @@ class AccessLogGenerator(object): | |||
|                 length=details['size'], | ||||
|             ) | ||||
|         # Websocket requests | ||||
|         elif protocol == "websocket" and action == "connecting": | ||||
|             self.write_entry( | ||||
|                 host=details['client'], | ||||
|                 date=datetime.datetime.now(), | ||||
|                 request="WSCONNECTING %(path)s" % details, | ||||
|             ) | ||||
|         elif protocol == "websocket" and action == "rejected": | ||||
|             self.write_entry( | ||||
|                 host=details['client'], | ||||
|                 date=datetime.datetime.now(), | ||||
|                 request="WSREJECT %(path)s" % details, | ||||
|             ) | ||||
|         elif protocol == "websocket" and action == "connected": | ||||
|             self.write_entry( | ||||
|                 host=details['client'], | ||||
|  |  | |||
|  | @ -281,12 +281,13 @@ class HTTPFactory(http.HTTPFactory): | |||
| 
 | ||||
|     protocol = HTTPProtocol | ||||
| 
 | ||||
|     def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path=""): | ||||
|     def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30): | ||||
|         http.HTTPFactory.__init__(self) | ||||
|         self.channel_layer = channel_layer | ||||
|         self.action_logger = action_logger | ||||
|         self.timeout = timeout | ||||
|         self.websocket_timeout = websocket_timeout | ||||
|         self.websocket_connect_timeout = websocket_connect_timeout | ||||
|         self.ping_interval = ping_interval | ||||
|         # We track all sub-protocols for response channel mapping | ||||
|         self.reply_protocols = {} | ||||
|  | @ -304,21 +305,37 @@ class HTTPFactory(http.HTTPFactory): | |||
|         if channel.startswith("http") and isinstance(self.reply_protocols[channel], WebRequest): | ||||
|             self.reply_protocols[channel].serverResponse(message) | ||||
|         elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol): | ||||
|             # Ensure the message is a valid WebSocket one | ||||
|             unknown_message_keys = set(message.keys()) - {"bytes", "text", "close"} | ||||
|             if unknown_message_keys: | ||||
|             # Switch depending on current socket state | ||||
|             protocol = self.reply_protocols[channel] | ||||
|             # See if the message is valid | ||||
|             non_accept_keys = set(message.keys()) - {"accept"} | ||||
|             non_send_keys = set(message.keys()) - {"bytes", "text", "close"} | ||||
|             if non_accept_keys and non_send_keys: | ||||
|                 raise ValueError( | ||||
|                     "Got invalid WebSocket reply message on %s - contains unknown keys %s" % ( | ||||
|                     "Got invalid WebSocket reply message on %s - " | ||||
|                     "contains unknown keys %s (looking for either {'accept'} or {'text', 'bytes', 'close'})" % ( | ||||
|                         channel, | ||||
|                         unknown_message_keys, | ||||
|                     ) | ||||
|                 ) | ||||
|             if "accept" in message: | ||||
|                 if protocol.state != protocol.STATE_CONNECTING: | ||||
|                     raise ValueError( | ||||
|                         "Got invalid WebSocket connection reply message on %s - websocket is not in handshake phase" % ( | ||||
|                             channel, | ||||
|                         ) | ||||
|                     ) | ||||
|                 if message['accept']: | ||||
|                     protocol.serverAccept() | ||||
|                 else: | ||||
|                     protocol.serverReject() | ||||
|             else: | ||||
|                 if message.get("bytes", None): | ||||
|                 self.reply_protocols[channel].serverSend(message["bytes"], True) | ||||
|                     protocol.serverSend(message["bytes"], True) | ||||
|                 if message.get("text", None): | ||||
|                 self.reply_protocols[channel].serverSend(message["text"], False) | ||||
|                     protocol.serverSend(message["text"], False) | ||||
|                 if message.get("close", False): | ||||
|                 self.reply_protocols[channel].serverClose() | ||||
|                     protocol.serverClose() | ||||
|         else: | ||||
|             raise ValueError("Cannot dispatch message on channel %r" % channel) | ||||
| 
 | ||||
|  |  | |||
|  | @ -5,6 +5,7 @@ import six | |||
| import time | ||||
| import traceback | ||||
| from six.moves.urllib_parse import unquote, urlencode | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory | ||||
| 
 | ||||
|  | @ -27,6 +28,7 @@ class WebSocketProtocol(WebSocketServerProtocol): | |||
|     def onConnect(self, request): | ||||
|         self.request = request | ||||
|         self.packets_received = 0 | ||||
|         self.protocol_to_accept = None | ||||
|         self.socket_opened = time.time() | ||||
|         self.last_data = time.time() | ||||
|         try: | ||||
|  | @ -78,8 +80,31 @@ class WebSocketProtocol(WebSocketServerProtocol): | |||
|                         ws_protocol = protocol | ||||
|                         break | ||||
| 
 | ||||
|         # Work out what subprotocol we will accept, if any | ||||
|         if ws_protocol and ws_protocol in self.factory.protocols: | ||||
|             return ws_protocol | ||||
|             self.protocol_to_accept = ws_protocol | ||||
|         else: | ||||
|             self.protocol_to_accept = None | ||||
| 
 | ||||
|         # Send over the connect message | ||||
|         try: | ||||
|             self.channel_layer.send("websocket.connect", self.request_info) | ||||
|         except self.channel_layer.ChannelFull: | ||||
|             # You have to consume websocket.connect according to the spec, | ||||
|             # so drop the connection. | ||||
|             self.muted = True | ||||
|             logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel) | ||||
|             # Send code 1013 "try again later" with close. | ||||
|             raise ConnectionDeny(code=503, reason="Connection queue at capacity") | ||||
|         else: | ||||
|             self.factory.log_action("websocket", "connecting", { | ||||
|                 "path": self.request.path, | ||||
|                 "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, | ||||
|             }) | ||||
| 
 | ||||
|         # Make a deferred and return it - we'll either call it or err it later on | ||||
|         self.handshake_deferred = defer.Deferred() | ||||
|         return self.handshake_deferred | ||||
| 
 | ||||
|     @classmethod | ||||
|     def unquote(cls, value): | ||||
|  | @ -93,17 +118,7 @@ class WebSocketProtocol(WebSocketServerProtocol): | |||
| 
 | ||||
|     def onOpen(self): | ||||
|         # Send news that this channel is open | ||||
|         logger.debug("WebSocket open for %s", self.reply_channel) | ||||
|         try: | ||||
|             self.channel_layer.send("websocket.connect", self.request_info) | ||||
|         except self.channel_layer.ChannelFull: | ||||
|             # You have to consume websocket.connect according to the spec, | ||||
|             # so drop the connection. | ||||
|             self.muted = True | ||||
|             logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel) | ||||
|             # Send code 1013 "try again later" with close. | ||||
|             self.sendCloseFrame(code=1013, isReply=False) | ||||
|         else: | ||||
|         logger.debug("WebSocket %s open and established", self.reply_channel) | ||||
|         self.factory.log_action("websocket", "connected", { | ||||
|             "path": self.request.path, | ||||
|             "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, | ||||
|  | @ -140,10 +155,31 @@ class WebSocketProtocol(WebSocketServerProtocol): | |||
|             # Send code 1013 "try again later" with close. | ||||
|             self.sendCloseFrame(code=1013, isReply=False) | ||||
| 
 | ||||
|     def serverAccept(self): | ||||
|         """ | ||||
|         Called when we get a message saying to accept the connection. | ||||
|         """ | ||||
|         self.handshake_deferred.callback(self.protocol_to_accept) | ||||
|         logger.debug("WebSocket %s accepted by application", self.reply_channel) | ||||
| 
 | ||||
|     def serverReject(self): | ||||
|         """ | ||||
|         Called when we get a message saying to accept the connection. | ||||
|         """ | ||||
|         self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) | ||||
|         self.cleanup() | ||||
|         logger.debug("WebSocket %s rejected by application", self.reply_channel) | ||||
|         self.factory.log_action("websocket", "rejected", { | ||||
|             "path": self.request.path, | ||||
|             "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, | ||||
|         }) | ||||
| 
 | ||||
|     def serverSend(self, content, binary=False): | ||||
|         """ | ||||
|         Server-side channel message to send a message. | ||||
|         """ | ||||
|         if self.state == self.STATE_CONNECTING: | ||||
|             self.serverAccept() | ||||
|         self.last_data = time.time() | ||||
|         logger.debug("Sent WebSocket packet to client for %s", self.reply_channel) | ||||
|         if binary: | ||||
|  | @ -158,9 +194,9 @@ class WebSocketProtocol(WebSocketServerProtocol): | |||
|         self.sendClose() | ||||
| 
 | ||||
|     def onClose(self, wasClean, code, reason): | ||||
|         self.cleanup() | ||||
|         if hasattr(self, "reply_channel"): | ||||
|             logger.debug("WebSocket closed for %s", self.reply_channel) | ||||
|             del self.factory.reply_protocols[self.reply_channel] | ||||
|             try: | ||||
|                 if not self.muted: | ||||
|                     self.channel_layer.send("websocket.disconnect", { | ||||
|  | @ -178,6 +214,13 @@ class WebSocketProtocol(WebSocketServerProtocol): | |||
|         else: | ||||
|             logger.debug("WebSocket closed before handshake established") | ||||
| 
 | ||||
|     def cleanup(self): | ||||
|         """ | ||||
|         Call to clean up this socket after it's closed. | ||||
|         """ | ||||
|         if hasattr(self, "reply_channel"): | ||||
|             del self.factory.reply_protocols[self.reply_channel] | ||||
| 
 | ||||
|     def duration(self): | ||||
|         """ | ||||
|         Returns the time since the socket was opened | ||||
|  | @ -186,8 +229,13 @@ class WebSocketProtocol(WebSocketServerProtocol): | |||
| 
 | ||||
|     def check_ping(self): | ||||
|         """ | ||||
|         Checks to see if we should send a keepalive ping. | ||||
|         Checks to see if we should send a keepalive ping/deny socket connection | ||||
|         """ | ||||
|         # If we're still connecting, deny the connection | ||||
|         if self.state == self.STATE_CONNECTING: | ||||
|             if self.duration() > self.main_factory.websocket_connect_timeout: | ||||
|                 self.serverReject() | ||||
|         elif self.state == self.STATE_OPEN: | ||||
|             if (time.time() - self.last_data) > self.main_factory.ping_interval: | ||||
|                 self._sendAutoPing() | ||||
|                 self.last_data = time.time() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user