diff --git a/daphne/cli.py b/daphne/cli.py index a35f51e..71faf0c 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -131,12 +131,8 @@ class CommandLineInterface(object): default=2, ) self.parser.add_argument( - 'channel_layer', - help='The ASGI channel layer instance to use as path.to.module:instance.path', - ) - self.parser.add_argument( - 'consumer', - help='The consumer to dispatch to as path.to.module:instance.path', + 'application', + help='The application to dispatch to as path.to.module:instance.path', ) self.server = None @@ -161,6 +157,7 @@ class CommandLineInterface(object): 0: logging.WARN, 1: logging.INFO, 2: logging.DEBUG, + 3: logging.DEBUG, # Also turns on asyncio debug }[args.verbosity], format="%(asctime)-15s %(levelname)-8s %(message)s", ) @@ -173,10 +170,9 @@ class CommandLineInterface(object): access_log_stream = open(args.access_log, "a", 1) elif args.verbosity >= 1: access_log_stream = sys.stdout - # Import application and channel_layer + # Import application sys.path.insert(0, ".") - channel_layer = import_by_path(args.channel_layer) - consumer = import_by_path(args.consumer) + application = import_by_path(args.application) # Set up port/host bindings if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]): # no advanced binding options passed, patch in defaults @@ -198,12 +194,11 @@ class CommandLineInterface(object): ) # Start the server logger.info( - 'Starting server at %s, channel layer %s.' % - (', '.join(endpoints), args.channel_layer) + 'Starting server at %s' % + (', '.join(endpoints), ) ) self.server = Server( - channel_layer=channel_layer, - consumer=consumer, + application=application, endpoints=endpoints, http_timeout=args.http_timeout, ping_interval=args.ping_interval, diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 0e87b49..6b4ece6 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -10,6 +10,7 @@ import traceback from zope.interface import implementer from six.moves.urllib_parse import unquote, unquote_plus +from twisted.internet.defer import ensureDeferred from twisted.internet.interfaces import IProtocolNegotiationFactory from twisted.protocols.policies import ProtocolWrapper from twisted.web import http @@ -29,7 +30,7 @@ class WebRequest(http.Request): GET and POST out. """ - consumer_type = "http" + application_type = "http" error_template = """ @@ -55,7 +56,7 @@ class WebRequest(http.Request): http.Request.__init__(self, *args, **kwargs) # Easy server link self.server = self.channel.factory.server - self.consumer_channel = None + self.application_queue = None self._response_started = False self.server.add_protocol(self) except Exception: @@ -137,8 +138,8 @@ class WebRequest(http.Request): # Boring old HTTP. else: - # Create consumer to handle this connection - self.consumer_channel = self.server.create_consumer(self) + # Create application to handle this connection + self.application_queue = self.server.create_application(self) # Sanitize and decode headers, potentially extracting root path self.clean_headers = [] self.root_path = self.server.root_path @@ -151,11 +152,10 @@ class WebRequest(http.Request): self.root_path = self.unquote(value) else: self.clean_headers.append((name.lower(), value)) - logger.debug("HTTP %s request for %s", self.method, self.consumer_channel) + logger.debug("HTTP %s request for %s", self.method, self.client_addr) self.content.seek(0, 0) - # Run consumer - self.server.handle_message( - self.consumer_channel, + # Run application against request + self.application_queue.put_nowait( { "type": "http.request", # TODO: Correctly say if it's 1.1 or 1.0 @@ -173,13 +173,13 @@ class WebRequest(http.Request): ) except Exception: logger.error(traceback.format_exc()) - self.basic_error(500, b"Internal Server Error", "HTTP processing error") + self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error") def connectionLost(self, reason): """ Cleans up reply channel on close. """ - if self.consumer_channel: + if self.application_queue: self.send_disconnect() logger.debug("HTTP disconnect for %s", self.client_addr) http.Request.connectionLost(self, reason) @@ -189,7 +189,7 @@ class WebRequest(http.Request): """ Cleans up reply channel on close. """ - if self.consumer_channel: + if self.application_queue: self.send_disconnect() logger.debug("HTTP close for %s", self.client_addr) http.Request.finish(self) @@ -201,6 +201,8 @@ class WebRequest(http.Request): """ Handles a reply from the client """ + if "type" not in message: + raise ValueError("Message has no type defined") if message['type'] == "http.response": if self._response_started: raise ValueError("HTTP response has already been started") @@ -216,9 +218,9 @@ class WebRequest(http.Request): header = header.encode("latin1") self.responseHeaders.addRawHeader(header, value) logger.debug("HTTP %s response started for %s", message['status'], self.client_addr) - elif message['type'] == "http.response.content": + elif message['type'] == "http.response.hunk": if not self._response_started: - raise ValueError("HTTP response has not yet been started") + raise ValueError("HTTP response has not yet been started but got %s" % message['type']) else: raise ValueError("Cannot handle message type %s!" % message['type']) @@ -243,6 +245,12 @@ class WebRequest(http.Request): else: logger.debug("HTTP response chunk for %s", self.client_addr) + def handle_exception(self, exception): + """ + Called by the server when our application tracebacks + """ + self.basic_error(500, b"Internal Server Error", "Exception inside application.") + def check_timeouts(self): """ Called periodically to see if we should timeout something @@ -276,8 +284,7 @@ class WebRequest(http.Request): """ # If we don't yet have a path, then don't send as we never opened. if self.path: - self.server.handle_message( - self.consumer_channel, + self.application_queue.put_nowait( { "type": "http.disconnect", }, @@ -295,7 +302,8 @@ class WebRequest(http.Request): """ Responds with a server-level error page (very basic) """ - self.serverResponse({ + self.handle_reply({ + "type": "http.response", "status": status, "status_text": status_text, "headers": [ @@ -340,83 +348,6 @@ class HTTPFactory(http.HTTPFactory): logger.error("Cannot build protocol: %s" % traceback.format_exc()) raise - def make_send_channel(self): - """ - Makes a new send channel for a protocol with our process prefix. - """ - protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10)) - return self.send_channel + protocol_id - - def dispatch_reply(self, channel, message): - # If we don't know about the channel, ignore it (likely a channel we - # used to have that's now in a group). - # TODO: Find a better way of alerting people when this happens so - # they can do more cleanup, that's not an error. - if channel not in self.reply_protocols: - logger.debug("Message on unknown channel %r - discarding" % channel) - return - - if isinstance(self.reply_protocols[channel], WebRequest): - self.reply_protocols[channel].serverResponse(message) - elif isinstance(self.reply_protocols[channel], WebSocketProtocol): - # Switch depending on current socket state - protocol = self.reply_protocols[channel] - # See if the message is valid - unknown_keys = set(message.keys()) - {"bytes", "text", "close", "accept"} - if unknown_keys: - raise ValueError( - "Got invalid WebSocket reply message on %s - " - "contains unknown keys %s (looking for either {'accept', 'text', 'bytes', 'close'})" % ( - channel, - unknown_keys, - ) - ) - # Accepts allow bytes/text afterwards - if message.get("accept", None) and protocol.state == protocol.STATE_CONNECTING: - protocol.serverAccept() - # Rejections must be the only thing - if message.get("accept", None) == False and protocol.state == protocol.STATE_CONNECTING: - protocol.serverReject() - return - # You're only allowed one of bytes or text - if message.get("bytes", None) and message.get("text", None): - raise ValueError( - "Got invalid WebSocket reply message on %s - contains both bytes and text keys" % ( - channel, - ) - ) - if message.get("bytes", None): - protocol.serverSend(message["bytes"], True) - if message.get("text", None): - protocol.serverSend(message["text"], False) - - closing_code = message.get("close", False) - if closing_code: - if protocol.state == protocol.STATE_CONNECTING: - protocol.serverReject() - else: - protocol.serverClose(code=closing_code) - else: - raise ValueError("Unknown protocol class") - - def check_timeouts(self): - """ - Runs through all HTTP protocol instances and times them out if they've - taken too long (and so their message is probably expired) - """ - for protocol in list(self.reply_protocols.values()): - # Web timeout checking - if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout: - protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.") - # WebSocket timeout checking and keepalive ping sending - elif isinstance(protocol, WebSocketProtocol): - # Timeout check - if protocol.duration() > self.websocket_timeout and self.websocket_timeout >= 0: - protocol.serverClose() - # Ping check - else: - protocol.check_ping() - # IProtocolNegotiationFactory def acceptableProtocols(self): """ diff --git a/daphne/server.py b/daphne/server.py index 7762504..8f43ee6 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -1,16 +1,21 @@ from __future__ import unicode_literals +# This has to be done first as Twisted is import-order-sensitive with reactors +from twisted.internet import asyncioreactor +asyncioreactor.install() + +import asyncio import collections import logging import random import string +import traceback import warnings from twisted.internet import reactor, defer from twisted.internet.endpoints import serverFromString from twisted.logger import globalLogBeginner, STDLibLogObserver from twisted.web import http -from twisted.python.threadpool import ThreadPool from .http_protocol import HTTPFactory from .ws_protocol import WebSocketFactory @@ -22,13 +27,12 @@ class Server(object): def __init__( self, - channel_layer, - consumer, + application, endpoints=None, signal_handlers=True, action_logger=None, http_timeout=120, - websocket_timeout=None, + websocket_timeout=86400, websocket_connect_timeout=20, ping_interval=20, ping_timeout=30, @@ -39,10 +43,9 @@ class Server(object): verbosity=1, websocket_handshake_timeout=5 ): - self.channel_layer = channel_layer - self.consumer = consumer + self.application = application self.endpoints = endpoints or [] - if len(self.endpoints) == 0: + if not self.endpoints: raise UserWarning("No endpoints. This server will not listen on anything.") self.listeners = [] self.signal_handlers = signal_handlers @@ -52,30 +55,20 @@ class Server(object): self.ping_timeout = ping_timeout self.proxy_forwarded_address_header = proxy_forwarded_address_header self.proxy_forwarded_port_header = proxy_forwarded_port_header - # If they did not provide a websocket timeout, default it to the - # channel layer's group_expiry value if present, or one day if not. - self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) + self.websocket_timeout = websocket_timeout self.websocket_connect_timeout = websocket_connect_timeout self.websocket_handshake_timeout = websocket_handshake_timeout - self.ws_protocols = ws_protocols + self.websocket_protocols = ws_protocols self.root_path = root_path self.verbosity = verbosity def run(self): - # Make the thread pool to run consumers in - # TODO: Configurable numbers of threads - self.pool = ThreadPool(name="consumers") - # Make the mapping of consumer instances to consumer channels - self.consumer_instances = {} # A set of current Twisted protocol instances to manage self.protocols = set() - # Create process-local channel prefixes - # TODO: Can we guarantee non-collision better? - process_id = "".join(random.choice(string.ascii_letters) for i in range(10)) - self.consumer_channel_prefix = "daphne.%s!" % process_id + self.application_instances = {} # Make the factory self.http_factory = HTTPFactory(self) - self.ws_factory = WebSocketFactory(self, protocols=self.ws_protocols, server='Daphne') + self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server='Daphne') self.ws_factory.setProtocolOptions( autoPingTimeout=self.ping_timeout, allowNullOrigin=True, @@ -93,18 +86,24 @@ class Server(object): else: logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)") - # Kick off the various background loops - reactor.callLater(0, self.backend_reader) + # Kick off the timeout loop + reactor.callLater(1, self.application_checker) reactor.callLater(2, self.timeout_checker) for socket_description in self.endpoints: logger.info("Listening on endpoint %s" % socket_description) - # Twisted requires str on python2 (not unicode) and str on python3 (not bytes) ep = serverFromString(reactor, str(socket_description)) self.listeners.append(ep.listen(self.http_factory)) - self.pool.start() - reactor.addSystemEventTrigger("before", "shutdown", self.pool.stop) + # Set the asyncio reactor's event loop as global + # TODO: Should we instead pass the global one into the reactor? + asyncio.set_event_loop(reactor._asyncioEventloop) + + # Verbosity 3 turns on asyncio debug to find those blocking yields + if self.verbosity >= 3: + asyncio.get_event_loop().set_debug(True) + + reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications) reactor.run(installSignalHandlers=self.signal_handlers) ### Protocol handling @@ -115,83 +114,86 @@ class Server(object): self.protocols.add(protocol) def discard_protocol(self, protocol): + # Ensure it's not in the protocol-tracking set self.protocols.discard(protocol) + # Make sure any application future that's running is cancelled + if protocol in self.application_instances: + self.application_instances[protocol].cancel() + del self.application_instances[protocol] ### Internal event/message handling - def create_consumer(self, protocol): + def create_application(self, protocol): """ - Creates a new consumer instance that fronts a Protocol instance + Creates a new application instance that fronts a Protocol instance for one of our supported protocols. Pass it the protocol, and it will work out the type, supply appropriate callables, and - put it into the server's consumer pool. - - It returns the consumer channel name, which is how you should refer - to the consumer instance. + return you the application's input queue """ - # Make sure the protocol defines a consumer type - assert protocol.consumer_type is not None - # Make it a consumer channel name - protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10)) - consumer_channel = self.consumer_channel_prefix + protocol_id - # Make an instance of the consumer - consumer_instance = self.consumer( - type=protocol.consumer_type, + # Make sure the protocol defines a application type + assert protocol.application_type is not None + # Make sure the protocol has not had another application made for it + assert protocol not in self.application_instances + # Make an instance of the application + input_queue = asyncio.Queue() + application_instance = asyncio.ensure_future(self.application( + type=protocol.application_type, + next=input_queue.get, reply=lambda message: self.handle_reply(protocol, message), - channel_layer=self.channel_layer, - consumer_channel=consumer_channel, - ) - # Assign it by channel and return it - self.consumer_instances[consumer_channel] = consumer_instance - return consumer_channel + ), loop=asyncio.get_event_loop()) + self.application_instances[protocol] = application_instance + return input_queue - def handle_message(self, consumer_channel, message): + async def handle_reply(self, protocol, message): """ - Schedules the application instance to handle the given message. + Coroutine that jumps the reply message from asyncio to Twisted """ - self.pool.callInThread(self.consumer_instances[consumer_channel], message) - - def handle_reply(self, protocol, message): - """ - Schedules the reply to be handled by the protocol in the main thread - """ - reactor.callFromThread(reactor.callLater, 0, protocol.handle_reply, message) - - ### External event/message handling - - def backend_reader(self): - """ - Runs as an-often-as-possible task with the reactor, unless there was - no result previously in which case we add a small delay. - """ - channels = [self.consumer_channel_prefix] - delay = 0 - # Quit if reactor is stopping - if not reactor.running: - logger.debug("Backend reader quitting due to reactor stop") - return - # Try to receive a message - try: - channel, message = self.channel_layer.receive(channels, block=False) - except Exception as e: - # Log the error and wait a bit to retry - logger.error('Error trying to receive messages: %s' % e) - delay = 5.00 - else: - if channel: - # Deal with the message - try: - self.handle_message(channel, message) - except Exception as e: - logger.error("Error handling external message: %s" % e) - else: - # If there's no messages, idle a little bit. - delay = 0.05 - # We can't loop inside here as this is synchronous code. - reactor.callLater(delay, self.backend_reader) + reactor.callLater(0, protocol.handle_reply, message) ### Utility + def application_checker(self): + """ + Goes through the set of current application Futures and cleans up + any that are done/prints exceptions for any that errored. + """ + for protocol, application_instance in list(self.application_instances.items()): + if application_instance.done(): + exception = application_instance.exception() + if exception: + logging.error( + "Exception inside application: {}\n{}{}".format( + exception, + "".join(traceback.format_tb( + exception.__traceback__, + )), + " {}".format(exception), + ) + ) + protocol.handle_exception(exception) + try: + del self.application_instances[protocol] + except KeyError: + # The protocol might have already got here before us. That's fine. + pass + reactor.callLater(1, self.application_checker) + + def kill_all_applications(self): + """ + Kills all application coroutines before reactor exit. + """ + # Send cancel to all coroutines + wait_for = [] + for application_instance in self.application_instances.values(): + if not application_instance.done(): + application_instance.cancel() + wait_for.append(application_instance) + logging.info("Killed %i pending application instances" % len(wait_for)) + # Make Twisted wait until they're all dead + wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for)) + wait_deferred.addErrback(lambda x: None) + return wait_deferred + def timeout_checker(self): """ Called periodically to enforce timeout rules on all connections. diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index a1ab19b..192fee2 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -20,20 +20,21 @@ class WebSocketProtocol(WebSocketServerProtocol): the websocket channels. """ + application_type = "websocket" + # If we should send no more messages (e.g. we error-closed the socket) muted = False - def __init__(self, *args, **kwargs): - super(WebSocketProtocol, self).__init__(*args, **kwargs) - self.server = self.factory.server_class - def onConnect(self, request): + self.server = self.factory.server_class + self.server.add_protocol(self) self.request = request - self.packets_received = 0 self.protocol_to_accept = None self.socket_opened = time.time() self.last_data = time.time() try: + # Make new application instance + self.application_queue = self.server.create_application(self) # Sanitize and decode headers self.clean_headers = [] for name, value in request.headers.items(): @@ -42,10 +43,6 @@ class WebSocketProtocol(WebSocketServerProtocol): if b"_" in name: continue self.clean_headers.append((name.lower(), value.encode("latin1"))) - # Make sending channel - self.reply_channel = self.main_factory.make_send_channel() - # Tell main factory about it - self.main_factory.reply_protocols[self.reply_channel] = self # Get client address if possible peer = self.transport.getPeer() host = self.transport.getHost() @@ -56,23 +53,23 @@ class WebSocketProtocol(WebSocketServerProtocol): self.client_addr = None self.server_addr = None - if self.main_factory.proxy_forwarded_address_header: + if self.server.proxy_forwarded_address_header: self.client_addr = parse_x_forwarded_for( self.http_headers, - self.main_factory.proxy_forwarded_address_header, - self.main_factory.proxy_forwarded_port_header, + self.server.proxy_forwarded_address_header, + self.server.proxy_forwarded_port_header, self.client_addr ) # Make initial request info dict from request (we only have it here) self.path = request.path.encode("ascii") - self.request_info = { + self.connect_message = { + "type": "websocket.connect", "path": self.unquote(self.path), "headers": self.clean_headers, "query_string": self._raw_query_string, # Passed by HTTP protocol "client": self.client_addr, "server": self.server_addr, - "reply_channel": self.reply_channel, "order": 0, } except: @@ -86,50 +83,33 @@ class WebSocketProtocol(WebSocketServerProtocol): if header == b'sec-websocket-protocol': protocols = [x.strip() for x in self.unquote(value).split(",")] for protocol in protocols: - if protocol in self.factory.protocols: + if protocol in self.server.websocket_protocols: ws_protocol = protocol break # Work out what subprotocol we will accept, if any - if ws_protocol and ws_protocol in self.factory.protocols: + if ws_protocol and ws_protocol in self.server.websocket_protocols: self.protocol_to_accept = ws_protocol else: self.protocol_to_accept = None # Send over the connect message - try: - self.channel_layer.send("websocket.connect", self.request_info) - except self.channel_layer.ChannelFull: - # You have to consume websocket.connect according to the spec, - # so drop the connection. - self.muted = True - logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel) - # Send code 503 "Service Unavailable" with close. - raise ConnectionDeny(code=503, reason="Connection queue at capacity") - else: - self.factory.log_action("websocket", "connecting", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + self.application_queue.put_nowait(self.connect_message) + self.server.log_action("websocket", "connecting", { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, + }) # Make a deferred and return it - we'll either call it or err it later on self.handshake_deferred = defer.Deferred() return self.handshake_deferred - @classmethod - def unquote(cls, value): - """ - Python 2 and 3 compat layer for utf-8 unquoting - """ - if six.PY2: - return unquote(value).decode("utf8") - else: - return unquote(value.decode("ascii")) + ### Twisted event handling def onOpen(self): # Send news that this channel is open - logger.debug("WebSocket %s open and established", self.reply_channel) - self.factory.log_action("websocket", "connected", { + logger.debug("WebSocket %s open and established", self.client_addr) + self.server.log_action("websocket", "connected", { "path": self.request.path, "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, }) @@ -137,49 +117,78 @@ class WebSocketProtocol(WebSocketServerProtocol): def onMessage(self, payload, isBinary): # If we're muted, do nothing. if self.muted: - logger.debug("Muting incoming frame on %s", self.reply_channel) + logger.debug("Muting incoming frame on %s", self.client_addr) return - logger.debug("WebSocket incoming frame on %s", self.reply_channel) - self.packets_received += 1 + logger.debug("WebSocket incoming frame on %s", self.client_addr) self.last_data = time.time() - try: - if isBinary: - self.channel_layer.send("websocket.receive", { - "reply_channel": self.reply_channel, - "path": self.unquote(self.path), - "order": self.packets_received, - "bytes": payload, - }) + if isBinary: + self.application_queue.put_nowait({ + "type": "websocket.receive", + "bytes": payload, + }) + else: + self.application_queue.put_nowait({ + "type": "websocket.receive", + "text": payload.decode("utf8"), + }) + + def onClose(self, wasClean, code, reason): + """ + Called when Twisted closes the socket. + """ + self.server.discard_protocol(self) + logger.debug("WebSocket closed for %s", self.client_addr) + if not self.muted: + self.application_queue.put_nowait({ + "type": "websocket.disconnect", + "code": code, + }) + self.server.log_action("websocket", "disconnected", { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, + }) + + ### Internal event handling + + def handle_reply(self, message): + if "type" not in message: + raise ValueError("Message has no type defined") + if message["type"] == "websocket.accept": + self.serverAccept() + elif message["type"] == "websocket.close": + if self.state == self.STATE_CONNECTING: + self.serverReject() else: - self.channel_layer.send("websocket.receive", { - "reply_channel": self.reply_channel, - "path": self.unquote(self.path), - "order": self.packets_received, - "text": payload.decode("utf8"), - }) - except self.channel_layer.ChannelFull: - # You have to consume websocket.receive according to the spec, - # so drop the connection. - self.muted = True - logger.warn("WebSocket force closed for %s due to receive backpressure", self.reply_channel) - # Send code 1013 "try again later" with close. - self.sendCloseFrame(code=1013, isReply=False) + self.serverClose(code=message.get("code", None)) + elif message["type"] == "websocket.send": + if self.state == self.STATE_CONNECTING: + raise ValueError("Socket has not been accepted, so cannot send over it") + if message.get("bytes", None) and message.get("text", None): + raise ValueError( + "Got invalid WebSocket reply message on %s - contains both bytes and text keys" % ( + channel, + ) + ) + if message.get("bytes", None): + self.serverSend(message["bytes"], True) + if message.get("text", None): + self.serverSend(message["text"], False) def serverAccept(self): """ Called when we get a message saying to accept the connection. """ self.handshake_deferred.callback(self.protocol_to_accept) - logger.debug("WebSocket %s accepted by application", self.reply_channel) + logger.debug("WebSocket %s accepted by application", self.client_addr) def serverReject(self): """ Called when we get a message saying to reject the connection. """ self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) - self.cleanup() - logger.debug("WebSocket %s rejected by application", self.reply_channel) - self.factory.log_action("websocket", "rejected", { + self.server.discard_protocol(self) + logger.debug("WebSocket %s rejected by application", self.client_addr) + self.server.log_action("websocket", "rejected", { "path": self.request.path, "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, }) @@ -191,47 +200,30 @@ class WebSocketProtocol(WebSocketServerProtocol): if self.state == self.STATE_CONNECTING: self.serverAccept() self.last_data = time.time() - logger.debug("Sent WebSocket packet to client for %s", self.reply_channel) + logger.debug("Sent WebSocket packet to client for %s", self.client_addr) if binary: self.sendMessage(content, binary) else: self.sendMessage(content.encode("utf8"), binary) - def serverClose(self, code=True): + def serverClose(self, code=None): """ Server-side channel message to close the socket """ - code = 1000 if code is True else code + code = 1000 if code is None else code self.sendClose(code=code) - def onClose(self, wasClean, code, reason): - self.cleanup() - if hasattr(self, "reply_channel"): - logger.debug("WebSocket closed for %s", self.reply_channel) - try: - if not self.muted: - self.channel_layer.send("websocket.disconnect", { - "reply_channel": self.reply_channel, - "code": code, - "path": self.unquote(self.path), - "order": self.packets_received + 1, - }) - except self.channel_layer.ChannelFull: - pass - self.factory.log_action("websocket", "disconnected", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) - else: - logger.debug("WebSocket closed before handshake established") + ### Utils - def cleanup(self): + @classmethod + def unquote(cls, value): """ - Call to clean up this socket after it's closed. + Python 2 and 3 compat layer for utf-8 unquoting """ - if hasattr(self, "reply_channel"): - if self.reply_channel in self.factory.reply_protocols: - del self.factory.reply_protocols[self.reply_channel] + if six.PY2: + return unquote(value).decode("utf8") + else: + return unquote(value.decode("ascii")) def duration(self): """ @@ -275,3 +267,15 @@ class WebSocketFactory(WebSocketServerFactory): def __init__(self, server_class, *args, **kwargs): self.server_class = server_class WebSocketServerFactory.__init__(self, *args, **kwargs) + + def buildProtocol(self, addr): + """ + Builds protocol instances. We use this to inject the factory object into the protocol. + """ + try: + protocol = super(WebSocketFactory, self).buildProtocol(addr) + protocol.factory = self + return protocol + except Exception as e: + logger.error("Cannot build protocol: %s" % traceback.format_exc()) + raise diff --git a/setup.py b/setup.py index fd45283..0cee5ea 100755 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ setup( include_package_data=True, install_requires=[ 'asgiref~=1.1', - 'twisted>=17.1', + 'twisted>=17.5', 'autobahn>=0.18', ], extras_require={