mirror of
				https://github.com/django/daphne.git
				synced 2025-10-30 23:37:25 +03:00 
			
		
		
		
	Switch to new explicit WebSocket acceptance
This commit is contained in:
		
							parent
							
								
									1bdf2a7518
								
							
						
					
					
						commit
						685f3aed1e
					
				|  | @ -24,6 +24,18 @@ class AccessLogGenerator(object): | ||||||
|                 length=details['size'], |                 length=details['size'], | ||||||
|             ) |             ) | ||||||
|         # Websocket requests |         # Websocket requests | ||||||
|  |         elif protocol == "websocket" and action == "connecting": | ||||||
|  |             self.write_entry( | ||||||
|  |                 host=details['client'], | ||||||
|  |                 date=datetime.datetime.now(), | ||||||
|  |                 request="WSCONNECTING %(path)s" % details, | ||||||
|  |             ) | ||||||
|  |         elif protocol == "websocket" and action == "rejected": | ||||||
|  |             self.write_entry( | ||||||
|  |                 host=details['client'], | ||||||
|  |                 date=datetime.datetime.now(), | ||||||
|  |                 request="WSREJECT %(path)s" % details, | ||||||
|  |             ) | ||||||
|         elif protocol == "websocket" and action == "connected": |         elif protocol == "websocket" and action == "connected": | ||||||
|             self.write_entry( |             self.write_entry( | ||||||
|                 host=details['client'], |                 host=details['client'], | ||||||
|  |  | ||||||
|  | @ -281,12 +281,13 @@ class HTTPFactory(http.HTTPFactory): | ||||||
| 
 | 
 | ||||||
|     protocol = HTTPProtocol |     protocol = HTTPProtocol | ||||||
| 
 | 
 | ||||||
|     def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path=""): |     def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30): | ||||||
|         http.HTTPFactory.__init__(self) |         http.HTTPFactory.__init__(self) | ||||||
|         self.channel_layer = channel_layer |         self.channel_layer = channel_layer | ||||||
|         self.action_logger = action_logger |         self.action_logger = action_logger | ||||||
|         self.timeout = timeout |         self.timeout = timeout | ||||||
|         self.websocket_timeout = websocket_timeout |         self.websocket_timeout = websocket_timeout | ||||||
|  |         self.websocket_connect_timeout = websocket_connect_timeout | ||||||
|         self.ping_interval = ping_interval |         self.ping_interval = ping_interval | ||||||
|         # We track all sub-protocols for response channel mapping |         # We track all sub-protocols for response channel mapping | ||||||
|         self.reply_protocols = {} |         self.reply_protocols = {} | ||||||
|  | @ -304,21 +305,37 @@ class HTTPFactory(http.HTTPFactory): | ||||||
|         if channel.startswith("http") and isinstance(self.reply_protocols[channel], WebRequest): |         if channel.startswith("http") and isinstance(self.reply_protocols[channel], WebRequest): | ||||||
|             self.reply_protocols[channel].serverResponse(message) |             self.reply_protocols[channel].serverResponse(message) | ||||||
|         elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol): |         elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol): | ||||||
|             # Ensure the message is a valid WebSocket one |             # Switch depending on current socket state | ||||||
|             unknown_message_keys = set(message.keys()) - {"bytes", "text", "close"} |             protocol = self.reply_protocols[channel] | ||||||
|             if unknown_message_keys: |             # See if the message is valid | ||||||
|  |             non_accept_keys = set(message.keys()) - {"accept"} | ||||||
|  |             non_send_keys = set(message.keys()) - {"bytes", "text", "close"} | ||||||
|  |             if non_accept_keys and non_send_keys: | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
|                     "Got invalid WebSocket reply message on %s - contains unknown keys %s" % ( |                     "Got invalid WebSocket reply message on %s - " | ||||||
|  |                     "contains unknown keys %s (looking for either {'accept'} or {'text', 'bytes', 'close'})" % ( | ||||||
|                         channel, |                         channel, | ||||||
|                         unknown_message_keys, |                         unknown_message_keys, | ||||||
|                     ) |                     ) | ||||||
|                 ) |                 ) | ||||||
|             if message.get("bytes", None): |             if "accept" in message: | ||||||
|                 self.reply_protocols[channel].serverSend(message["bytes"], True) |                 if protocol.state != protocol.STATE_CONNECTING: | ||||||
|             if message.get("text", None): |                     raise ValueError( | ||||||
|                 self.reply_protocols[channel].serverSend(message["text"], False) |                         "Got invalid WebSocket connection reply message on %s - websocket is not in handshake phase" % ( | ||||||
|             if message.get("close", False): |                             channel, | ||||||
|                 self.reply_protocols[channel].serverClose() |                         ) | ||||||
|  |                     ) | ||||||
|  |                 if message['accept']: | ||||||
|  |                     protocol.serverAccept() | ||||||
|  |                 else: | ||||||
|  |                     protocol.serverReject() | ||||||
|  |             else: | ||||||
|  |                 if message.get("bytes", None): | ||||||
|  |                     protocol.serverSend(message["bytes"], True) | ||||||
|  |                 if message.get("text", None): | ||||||
|  |                     protocol.serverSend(message["text"], False) | ||||||
|  |                 if message.get("close", False): | ||||||
|  |                     protocol.serverClose() | ||||||
|         else: |         else: | ||||||
|             raise ValueError("Cannot dispatch message on channel %r" % channel) |             raise ValueError("Cannot dispatch message on channel %r" % channel) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import six | ||||||
| import time | import time | ||||||
| import traceback | import traceback | ||||||
| from six.moves.urllib_parse import unquote, urlencode | from six.moves.urllib_parse import unquote, urlencode | ||||||
|  | from twisted.internet import defer | ||||||
| 
 | 
 | ||||||
| from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory | from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory | ||||||
| 
 | 
 | ||||||
|  | @ -27,6 +28,7 @@ class WebSocketProtocol(WebSocketServerProtocol): | ||||||
|     def onConnect(self, request): |     def onConnect(self, request): | ||||||
|         self.request = request |         self.request = request | ||||||
|         self.packets_received = 0 |         self.packets_received = 0 | ||||||
|  |         self.protocol_to_accept = None | ||||||
|         self.socket_opened = time.time() |         self.socket_opened = time.time() | ||||||
|         self.last_data = time.time() |         self.last_data = time.time() | ||||||
|         try: |         try: | ||||||
|  | @ -78,8 +80,31 @@ class WebSocketProtocol(WebSocketServerProtocol): | ||||||
|                         ws_protocol = protocol |                         ws_protocol = protocol | ||||||
|                         break |                         break | ||||||
| 
 | 
 | ||||||
|  |         # Work out what subprotocol we will accept, if any | ||||||
|         if ws_protocol and ws_protocol in self.factory.protocols: |         if ws_protocol and ws_protocol in self.factory.protocols: | ||||||
|             return ws_protocol |             self.protocol_to_accept = ws_protocol | ||||||
|  |         else: | ||||||
|  |             self.protocol_to_accept = None | ||||||
|  | 
 | ||||||
|  |         # Send over the connect message | ||||||
|  |         try: | ||||||
|  |             self.channel_layer.send("websocket.connect", self.request_info) | ||||||
|  |         except self.channel_layer.ChannelFull: | ||||||
|  |             # You have to consume websocket.connect according to the spec, | ||||||
|  |             # so drop the connection. | ||||||
|  |             self.muted = True | ||||||
|  |             logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel) | ||||||
|  |             # Send code 1013 "try again later" with close. | ||||||
|  |             raise ConnectionDeny(code=503, reason="Connection queue at capacity") | ||||||
|  |         else: | ||||||
|  |             self.factory.log_action("websocket", "connecting", { | ||||||
|  |                 "path": self.request.path, | ||||||
|  |                 "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, | ||||||
|  |             }) | ||||||
|  | 
 | ||||||
|  |         # Make a deferred and return it - we'll either call it or err it later on | ||||||
|  |         self.handshake_deferred = defer.Deferred() | ||||||
|  |         return self.handshake_deferred | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def unquote(cls, value): |     def unquote(cls, value): | ||||||
|  | @ -93,21 +118,11 @@ class WebSocketProtocol(WebSocketServerProtocol): | ||||||
| 
 | 
 | ||||||
|     def onOpen(self): |     def onOpen(self): | ||||||
|         # Send news that this channel is open |         # Send news that this channel is open | ||||||
|         logger.debug("WebSocket open for %s", self.reply_channel) |         logger.debug("WebSocket %s open and established", self.reply_channel) | ||||||
|         try: |         self.factory.log_action("websocket", "connected", { | ||||||
|             self.channel_layer.send("websocket.connect", self.request_info) |             "path": self.request.path, | ||||||
|         except self.channel_layer.ChannelFull: |             "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, | ||||||
|             # You have to consume websocket.connect according to the spec, |         }) | ||||||
|             # so drop the connection. |  | ||||||
|             self.muted = True |  | ||||||
|             logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel) |  | ||||||
|             # Send code 1013 "try again later" with close. |  | ||||||
|             self.sendCloseFrame(code=1013, isReply=False) |  | ||||||
|         else: |  | ||||||
|             self.factory.log_action("websocket", "connected", { |  | ||||||
|                 "path": self.request.path, |  | ||||||
|                 "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, |  | ||||||
|             }) |  | ||||||
| 
 | 
 | ||||||
|     def onMessage(self, payload, isBinary): |     def onMessage(self, payload, isBinary): | ||||||
|         # If we're muted, do nothing. |         # If we're muted, do nothing. | ||||||
|  | @ -140,10 +155,31 @@ class WebSocketProtocol(WebSocketServerProtocol): | ||||||
|             # Send code 1013 "try again later" with close. |             # Send code 1013 "try again later" with close. | ||||||
|             self.sendCloseFrame(code=1013, isReply=False) |             self.sendCloseFrame(code=1013, isReply=False) | ||||||
| 
 | 
 | ||||||
|  |     def serverAccept(self): | ||||||
|  |         """ | ||||||
|  |         Called when we get a message saying to accept the connection. | ||||||
|  |         """ | ||||||
|  |         self.handshake_deferred.callback(self.protocol_to_accept) | ||||||
|  |         logger.debug("WebSocket %s accepted by application", self.reply_channel) | ||||||
|  | 
 | ||||||
|  |     def serverReject(self): | ||||||
|  |         """ | ||||||
|  |         Called when we get a message saying to accept the connection. | ||||||
|  |         """ | ||||||
|  |         self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) | ||||||
|  |         self.cleanup() | ||||||
|  |         logger.debug("WebSocket %s rejected by application", self.reply_channel) | ||||||
|  |         self.factory.log_action("websocket", "rejected", { | ||||||
|  |             "path": self.request.path, | ||||||
|  |             "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, | ||||||
|  |         }) | ||||||
|  | 
 | ||||||
|     def serverSend(self, content, binary=False): |     def serverSend(self, content, binary=False): | ||||||
|         """ |         """ | ||||||
|         Server-side channel message to send a message. |         Server-side channel message to send a message. | ||||||
|         """ |         """ | ||||||
|  |         if self.state == self.STATE_CONNECTING: | ||||||
|  |             self.serverAccept() | ||||||
|         self.last_data = time.time() |         self.last_data = time.time() | ||||||
|         logger.debug("Sent WebSocket packet to client for %s", self.reply_channel) |         logger.debug("Sent WebSocket packet to client for %s", self.reply_channel) | ||||||
|         if binary: |         if binary: | ||||||
|  | @ -158,9 +194,9 @@ class WebSocketProtocol(WebSocketServerProtocol): | ||||||
|         self.sendClose() |         self.sendClose() | ||||||
| 
 | 
 | ||||||
|     def onClose(self, wasClean, code, reason): |     def onClose(self, wasClean, code, reason): | ||||||
|  |         self.cleanup() | ||||||
|         if hasattr(self, "reply_channel"): |         if hasattr(self, "reply_channel"): | ||||||
|             logger.debug("WebSocket closed for %s", self.reply_channel) |             logger.debug("WebSocket closed for %s", self.reply_channel) | ||||||
|             del self.factory.reply_protocols[self.reply_channel] |  | ||||||
|             try: |             try: | ||||||
|                 if not self.muted: |                 if not self.muted: | ||||||
|                     self.channel_layer.send("websocket.disconnect", { |                     self.channel_layer.send("websocket.disconnect", { | ||||||
|  | @ -178,6 +214,13 @@ class WebSocketProtocol(WebSocketServerProtocol): | ||||||
|         else: |         else: | ||||||
|             logger.debug("WebSocket closed before handshake established") |             logger.debug("WebSocket closed before handshake established") | ||||||
| 
 | 
 | ||||||
|  |     def cleanup(self): | ||||||
|  |         """ | ||||||
|  |         Call to clean up this socket after it's closed. | ||||||
|  |         """ | ||||||
|  |         if hasattr(self, "reply_channel"): | ||||||
|  |             del self.factory.reply_protocols[self.reply_channel] | ||||||
|  | 
 | ||||||
|     def duration(self): |     def duration(self): | ||||||
|         """ |         """ | ||||||
|         Returns the time since the socket was opened |         Returns the time since the socket was opened | ||||||
|  | @ -186,11 +229,16 @@ class WebSocketProtocol(WebSocketServerProtocol): | ||||||
| 
 | 
 | ||||||
|     def check_ping(self): |     def check_ping(self): | ||||||
|         """ |         """ | ||||||
|         Checks to see if we should send a keepalive ping. |         Checks to see if we should send a keepalive ping/deny socket connection | ||||||
|         """ |         """ | ||||||
|         if (time.time() - self.last_data) > self.main_factory.ping_interval: |         # If we're still connecting, deny the connection | ||||||
|             self._sendAutoPing() |         if self.state == self.STATE_CONNECTING: | ||||||
|             self.last_data = time.time() |             if self.duration() > self.main_factory.websocket_connect_timeout: | ||||||
|  |                 self.serverReject() | ||||||
|  |         elif self.state == self.STATE_OPEN: | ||||||
|  |             if (time.time() - self.last_data) > self.main_factory.ping_interval: | ||||||
|  |                 self._sendAutoPing() | ||||||
|  |                 self.last_data = time.time() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class WebSocketFactory(WebSocketServerFactory): | class WebSocketFactory(WebSocketServerFactory): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user