diff --git a/.gitignore b/.gitignore index e3535d3..830060b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ build/ .hypothesis .cache .eggs +test_layer* +test_consumer* diff --git a/daphne/cli.py b/daphne/cli.py index b27cb06..a35f51e 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -2,8 +2,10 @@ import sys import argparse import logging import importlib -from .server import Server, build_endpoint_description_strings +from .server import Server +from .endpoints import build_endpoint_description_strings from .access import AccessLogGenerator +from .utils import import_by_path logger = logging.getLogger(__name__) @@ -81,7 +83,7 @@ class CommandLineInterface(object): '-t', '--http-timeout', type=int, - help='How long to wait for worker server before timing out HTTP connections', + help='How long to wait for worker before timing out HTTP connections', default=120, ) self.parser.add_argument( @@ -123,16 +125,19 @@ class CommandLineInterface(object): action='store_true', ) self.parser.add_argument( - '--force-sync', - dest='force_sync', - action='store_true', - help='Force the server to use synchronous mode on its ASGI channel layer', - default=False, + '--threads', + help='Number of threads to run the application in', + type=int, + default=2, ) self.parser.add_argument( 'channel_layer', help='The ASGI channel layer instance to use as path.to.module:instance.path', ) + self.parser.add_argument( + 'consumer', + help='The consumer to dispatch to as path.to.module:instance.path', + ) self.server = None @@ -168,13 +173,11 @@ class CommandLineInterface(object): access_log_stream = open(args.access_log, "a", 1) elif args.verbosity >= 1: access_log_stream = sys.stdout - # Import channel layer + # Import application and channel_layer sys.path.insert(0, ".") - module_path, object_path = args.channel_layer.split(":", 1) - channel_layer = importlib.import_module(module_path) - for bit in object_path.split("."): - channel_layer = getattr(channel_layer, bit) - + channel_layer = import_by_path(args.channel_layer) + consumer = import_by_path(args.consumer) + # Set up port/host bindings if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]): # no advanced binding options passed, patch in defaults args.host = DEFAULT_HOST @@ -183,8 +186,7 @@ class CommandLineInterface(object): args.port = DEFAULT_PORT elif args.port and not args.host: args.host = DEFAULT_HOST - - # build endpoint description strings from (optional) cli arguments + # Build endpoint description strings from (optional) cli arguments endpoints = build_endpoint_description_strings( host=args.host, port=args.port, @@ -194,13 +196,14 @@ class CommandLineInterface(object): endpoints = sorted( args.socket_strings + endpoints ) + # Start the server logger.info( 'Starting server at %s, channel layer %s.' % (', '.join(endpoints), args.channel_layer) ) - self.server = Server( channel_layer=channel_layer, + consumer=consumer, endpoints=endpoints, http_timeout=args.http_timeout, ping_interval=args.ping_interval, @@ -213,6 +216,5 @@ class CommandLineInterface(object): verbosity=args.verbosity, proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None, proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None, - force_sync=args.force_sync, ) self.server.run() diff --git a/daphne/endpoints.py b/daphne/endpoints.py new file mode 100644 index 0000000..5904546 --- /dev/null +++ b/daphne/endpoints.py @@ -0,0 +1,25 @@ +def build_endpoint_description_strings( + host=None, + port=None, + unix_socket=None, + file_descriptor=None + ): + """ + Build a list of twisted endpoint description strings that the server will listen on. + This is to streamline the generation of twisted endpoint description strings from easier + to use command line args such as host, port, unix sockets etc. + """ + socket_descriptions = [] + if host and port: + host = host.strip('[]').replace(':', '\:') + socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host)) + elif any([host, port]): + raise ValueError('TCP binding requires both port and host kwargs.') + + if unix_socket: + socket_descriptions.append('unix:%s' % unix_socket) + + if file_descriptor is not None: + socket_descriptions.append('fd:fileno=%d' % int(file_descriptor)) + + return socket_descriptions diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 32dc276..0e87b49 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -29,6 +29,8 @@ class WebRequest(http.Request): GET and POST out. """ + consumer_type = "http" + error_template = """ @@ -51,18 +53,17 @@ class WebRequest(http.Request): def __init__(self, *args, **kwargs): try: http.Request.__init__(self, *args, **kwargs) - # Easy factory link - self.factory = self.channel.factory - # Make a name for our reply channel - self.reply_channel = self.factory.make_send_channel() - # Tell factory we're that channel's client - self.last_keepalive = time.time() - self.factory.reply_protocols[self.reply_channel] = self - self._got_response_start = False + # Easy server link + self.server = self.channel.factory.server + self.consumer_channel = None + self._response_started = False + self.server.add_protocol(self) except Exception: logger.error(traceback.format_exc()) raise + ### Twisted progress callbacks + def process(self): try: self.request_start = time.time() @@ -79,15 +80,14 @@ class WebRequest(http.Request): else: self.client_addr = None self.server_addr = None - - if self.factory.proxy_forwarded_address_header: + # See if we need to get the address from a proxy header instead + if self.server.proxy_forwarded_address_header: self.client_addr = parse_x_forwarded_for( self.requestHeaders, - self.factory.proxy_forwarded_address_header, - self.factory.proxy_forwarded_port_header, + self.server.proxy_forwarded_address_header, + self.server.proxy_forwarded_port_header, self.client_addr ) - # Check for unicodeish path (or it'll crash when trying to parse) try: self.path.decode("ascii") @@ -102,7 +102,7 @@ class WebRequest(http.Request): # Is it WebSocket? IS IT?! if upgrade_header and upgrade_header.lower() == b"websocket": # Make WebSocket protocol to hand off to - protocol = self.factory.ws_factory.buildProtocol(self.transport.getPeer()) + protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer()) if not protocol: # If protocol creation fails, we signal "internal server error" self.setResponseCode(500) @@ -111,7 +111,6 @@ class WebRequest(http.Request): # Give it the raw query string protocol._raw_query_string = self.query_string # Port across transport - protocol.set_main_factory(self.factory) transport, self.transport = self.transport, None if isinstance(transport, ProtocolWrapper): # i.e. TLS is a wrapping protocol @@ -127,12 +126,7 @@ class WebRequest(http.Request): data += self.content.read() protocol.dataReceived(data) # Remove our HTTP reply channel association - if hasattr(protocol, "reply_channel"): - logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel) - else: - logger.debug("Connection %s did not get successful WS handshake.", self.reply_channel) - del self.factory.reply_protocols[self.reply_channel] - self.reply_channel = None + logger.debug("Upgraded connection %s to WebSocket", self.client_addr) # Resume the producer so we keep getting data, if it's available as a method # 17.1 version if hasattr(self.channel, "_networkProducer"): @@ -143,9 +137,11 @@ class WebRequest(http.Request): # Boring old HTTP. else: + # Create consumer to handle this connection + self.consumer_channel = self.server.create_consumer(self) # Sanitize and decode headers, potentially extracting root path self.clean_headers = [] - self.root_path = self.factory.root_path + self.root_path = self.server.root_path for name, values in self.requestHeaders.getAllRawHeaders(): # Prevent CVE-2015-0219 if b"_" in name: @@ -155,12 +151,13 @@ class WebRequest(http.Request): self.root_path = self.unquote(value) else: self.clean_headers.append((name.lower(), value)) - logger.debug("HTTP %s request for %s", self.method, self.reply_channel) + logger.debug("HTTP %s request for %s", self.method, self.consumer_channel) self.content.seek(0, 0) - # Send message - try: - self.factory.channel_layer.send("http.request", { - "reply_channel": self.reply_channel, + # Run consumer + self.server.handle_message( + self.consumer_channel, + { + "type": "http.request", # TODO: Correctly say if it's 1.1 or 1.0 "http_version": self.clientproto.split(b"/")[-1].decode("ascii"), "method": self.method.decode("ascii"), @@ -172,14 +169,90 @@ class WebRequest(http.Request): "body": self.content.read(), "client": self.client_addr, "server": self.server_addr, - }) - except self.factory.channel_layer.ChannelFull: - # Channel is too full; reject request with 503 - self.basic_error(503, b"Service Unavailable", "Request queue full.") + }, + ) except Exception: logger.error(traceback.format_exc()) self.basic_error(500, b"Internal Server Error", "HTTP processing error") + def connectionLost(self, reason): + """ + Cleans up reply channel on close. + """ + if self.consumer_channel: + self.send_disconnect() + logger.debug("HTTP disconnect for %s", self.client_addr) + http.Request.connectionLost(self, reason) + self.server.discard_protocol(self) + + def finish(self): + """ + Cleans up reply channel on close. + """ + if self.consumer_channel: + self.send_disconnect() + logger.debug("HTTP close for %s", self.client_addr) + http.Request.finish(self) + self.server.discard_protocol(self) + + ### Server reply callbacks + + def handle_reply(self, message): + """ + Handles a reply from the client + """ + if message['type'] == "http.response": + if self._response_started: + raise ValueError("HTTP response has already been started") + self._response_started = True + if 'status' not in message: + raise ValueError("Specifying a status code is required for a Response message.") + # Set HTTP status code + self.setResponseCode(message['status']) + # Write headers + for header, value in message.get("headers", {}): + # Shim code from old ASGI version, can be removed after a while + if isinstance(header, six.text_type): + header = header.encode("latin1") + self.responseHeaders.addRawHeader(header, value) + logger.debug("HTTP %s response started for %s", message['status'], self.client_addr) + elif message['type'] == "http.response.content": + if not self._response_started: + raise ValueError("HTTP response has not yet been started") + else: + raise ValueError("Cannot handle message type %s!" % message['type']) + + # Write out body + http.Request.write(self, message.get('content', b'')) + + # End if there's no more content + if not message.get("more_content", False): + self.finish() + logger.debug("HTTP response complete for %s", self.client_addr) + try: + self.server.log_action("http", "complete", { + "path": self.uri.decode("ascii"), + "status": self.code, + "method": self.method.decode("ascii"), + "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, + "time_taken": self.duration(), + "size": self.sentLength, + }) + except Exception as e: + logging.error(traceback.format_exc()) + else: + logger.debug("HTTP response chunk for %s", self.client_addr) + + def check_timeouts(self): + """ + Called periodically to see if we should timeout something + """ + # Web timeout checking + if self.duration() > self.server.http_timeout: + self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.") + + ### Utility functions + @classmethod def unquote(cls, value, plus_as_space=False): """ @@ -198,81 +271,17 @@ class WebRequest(http.Request): def send_disconnect(self): """ - Sends a disconnect message on the http.disconnect channel. + Sends a http.disconnect message. Useful only really for long-polling. """ # If we don't yet have a path, then don't send as we never opened. if self.path: - try: - self.factory.channel_layer.send("http.disconnect", { - "reply_channel": self.reply_channel, - "path": self.unquote(self.path), - }) - except self.factory.channel_layer.ChannelFull: - pass - - def connectionLost(self, reason): - """ - Cleans up reply channel on close. - """ - if self.reply_channel and self.reply_channel in self.channel.factory.reply_protocols: - self.send_disconnect() - del self.channel.factory.reply_protocols[self.reply_channel] - logger.debug("HTTP disconnect for %s", self.reply_channel) - http.Request.connectionLost(self, reason) - - def finish(self): - """ - Cleans up reply channel on close. - """ - if self.reply_channel and self.reply_channel in self.channel.factory.reply_protocols: - self.send_disconnect() - del self.channel.factory.reply_protocols[self.reply_channel] - logger.debug("HTTP close for %s", self.reply_channel) - http.Request.finish(self) - - def serverResponse(self, message): - """ - Writes a received HTTP response back out to the transport. - """ - if not self._got_response_start: - self._got_response_start = True - if 'status' not in message: - raise ValueError("Specifying a status code is required for a Response message.") - - # Set HTTP status code - self.setResponseCode(message['status']) - # Write headers - for header, value in message.get("headers", {}): - # Shim code from old ASGI version, can be removed after a while - if isinstance(header, six.text_type): - header = header.encode("latin1") - self.responseHeaders.addRawHeader(header, value) - logger.debug("HTTP %s response started for %s", message['status'], self.reply_channel) - else: - if 'status' in message: - raise ValueError("Got multiple Response messages for %s!" % self.reply_channel) - - # Write out body - http.Request.write(self, message.get('content', b'')) - - # End if there's no more content - if not message.get("more_content", False): - self.finish() - logger.debug("HTTP response complete for %s", self.reply_channel) - try: - self.factory.log_action("http", "complete", { - "path": self.uri.decode("ascii"), - "status": self.code, - "method": self.method.decode("ascii"), - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - "time_taken": self.duration(), - "size": self.sentLength, - }) - except Exception as e: - logging.error(traceback.format_exc()) - else: - logger.debug("HTTP response chunk for %s", self.reply_channel) + self.server.handle_message( + self.consumer_channel, + { + "type": "http.disconnect", + }, + ) def duration(self): """ @@ -298,6 +307,12 @@ class WebRequest(http.Request): }).encode("utf8"), }) + def __hash__(self): + return hash(id(self)) + + def __eq__(self, other): + return id(self) == id(other) + @implementer(IProtocolNegotiationFactory) class HTTPFactory(http.HTTPFactory): @@ -308,30 +323,9 @@ class HTTPFactory(http.HTTPFactory): routed appropriately. """ - def __init__(self, channel_layer, action_logger=None, send_channel=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30, proxy_forwarded_address_header=None, proxy_forwarded_port_header=None, websocket_handshake_timeout=5): + def __init__(self, server): http.HTTPFactory.__init__(self) - self.channel_layer = channel_layer - self.action_logger = action_logger - self.send_channel = send_channel - assert self.send_channel is not None - self.timeout = timeout - self.websocket_timeout = websocket_timeout - self.websocket_connect_timeout = websocket_connect_timeout - self.ping_interval = ping_interval - self.proxy_forwarded_address_header = proxy_forwarded_address_header - self.proxy_forwarded_port_header = proxy_forwarded_port_header - # We track all sub-protocols for response channel mapping - self.reply_protocols = {} - # Make a factory for WebSocket protocols - self.ws_factory = WebSocketFactory(self, protocols=ws_protocols, server='Daphne') - self.ws_factory.setProtocolOptions( - autoPingTimeout=ping_timeout, - allowNullOrigin=True, - openHandshakeTimeout=websocket_handshake_timeout - ) - self.ws_factory.protocol = WebSocketProtocol - self.ws_factory.reply_protocols = self.reply_protocols - self.root_path = root_path + self.server = server def buildProtocol(self, addr): """ @@ -353,9 +347,6 @@ class HTTPFactory(http.HTTPFactory): protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10)) return self.send_channel + protocol_id - def reply_channels(self): - return self.reply_protocols.keys() - def dispatch_reply(self, channel, message): # If we don't know about the channel, ignore it (likely a channel we # used to have that's now in a group). @@ -408,13 +399,6 @@ class HTTPFactory(http.HTTPFactory): else: raise ValueError("Unknown protocol class") - def log_action(self, protocol, action, details): - """ - Dispatches to any registered action logger, if there is one. - """ - if self.action_logger: - self.action_logger(protocol, action, details) - def check_timeouts(self): """ Runs through all HTTP protocol instances and times them out if they've diff --git a/daphne/server.py b/daphne/server.py index 9340524..7762504 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import collections import logging import random import string @@ -9,8 +10,10 @@ from twisted.internet import reactor, defer from twisted.internet.endpoints import serverFromString from twisted.logger import globalLogBeginner, STDLibLogObserver from twisted.web import http +from twisted.python.threadpool import ThreadPool from .http_protocol import HTTPFactory +from .ws_protocol import WebSocketFactory logger = logging.getLogger(__name__) @@ -20,11 +23,8 @@ class Server(object): def __init__( self, channel_layer, - host=None, - port=None, + consumer, endpoints=None, - unix_socket=None, - file_descriptor=None, signal_handlers=True, action_logger=None, http_timeout=120, @@ -36,28 +36,14 @@ class Server(object): root_path="", proxy_forwarded_address_header=None, proxy_forwarded_port_header=None, - force_sync=False, verbosity=1, websocket_handshake_timeout=5 ): self.channel_layer = channel_layer + self.consumer = consumer self.endpoints = endpoints or [] - - if any([host, port, unix_socket, file_descriptor]): - warnings.warn(''' - The host/port/unix_socket/file_descriptor keyword arguments to %s are deprecated. - ''' % self.__class__.__name__, DeprecationWarning) - # build endpoint description strings from deprecated kwargs - self.endpoints = sorted(self.endpoints + build_endpoint_description_strings( - host=host, - port=port, - unix_socket=unix_socket, - file_descriptor=file_descriptor - )) - if len(self.endpoints) == 0: raise UserWarning("No endpoints. This server will not listen on anything.") - self.listeners = [] self.signal_handlers = signal_handlers self.action_logger = action_logger @@ -73,29 +59,27 @@ class Server(object): self.websocket_handshake_timeout = websocket_handshake_timeout self.ws_protocols = ws_protocols self.root_path = root_path - self.force_sync = force_sync self.verbosity = verbosity def run(self): + # Make the thread pool to run consumers in + # TODO: Configurable numbers of threads + self.pool = ThreadPool(name="consumers") + # Make the mapping of consumer instances to consumer channels + self.consumer_instances = {} + # A set of current Twisted protocol instances to manage + self.protocols = set() # Create process-local channel prefixes # TODO: Can we guarantee non-collision better? process_id = "".join(random.choice(string.ascii_letters) for i in range(10)) - self.send_channel = "daphne.response.%s!" % process_id + self.consumer_channel_prefix = "daphne.%s!" % process_id # Make the factory - self.factory = HTTPFactory( - self.channel_layer, - action_logger=self.action_logger, - send_channel=self.send_channel, - timeout=self.http_timeout, - websocket_timeout=self.websocket_timeout, - websocket_connect_timeout=self.websocket_connect_timeout, - ping_interval=self.ping_interval, - ping_timeout=self.ping_timeout, - ws_protocols=self.ws_protocols, - root_path=self.root_path, - proxy_forwarded_address_header=self.proxy_forwarded_address_header, - proxy_forwarded_port_header=self.proxy_forwarded_port_header, - websocket_handshake_timeout=self.websocket_handshake_timeout + self.http_factory = HTTPFactory(self) + self.ws_factory = WebSocketFactory(self, protocols=self.ws_protocols, server='Daphne') + self.ws_factory.setProtocolOptions( + autoPingTimeout=self.ping_timeout, + allowNullOrigin=True, + openHandshakeTimeout=self.websocket_handshake_timeout ) if self.verbosity <= 1: # Redirect the Twisted log to nowhere @@ -109,28 +93,78 @@ class Server(object): else: logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)") - if "twisted" in self.channel_layer.extensions and not self.force_sync: - logger.info("Using native Twisted mode on channel layer") - reactor.callLater(0, self.backend_reader_twisted) - else: - logger.info("Using busy-loop synchronous mode on channel layer") - reactor.callLater(0, self.backend_reader_sync) + # Kick off the various background loops + reactor.callLater(0, self.backend_reader) reactor.callLater(2, self.timeout_checker) for socket_description in self.endpoints: logger.info("Listening on endpoint %s" % socket_description) # Twisted requires str on python2 (not unicode) and str on python3 (not bytes) ep = serverFromString(reactor, str(socket_description)) - self.listeners.append(ep.listen(self.factory)) + self.listeners.append(ep.listen(self.http_factory)) + self.pool.start() + reactor.addSystemEventTrigger("before", "shutdown", self.pool.stop) reactor.run(installSignalHandlers=self.signal_handlers) - def backend_reader_sync(self): + ### Protocol handling + + def add_protocol(self, protocol): + if protocol in self.protocols: + raise RuntimeError("Protocol %r was added to main list twice!" % protocol) + self.protocols.add(protocol) + + def discard_protocol(self, protocol): + self.protocols.discard(protocol) + + ### Internal event/message handling + + def create_consumer(self, protocol): + """ + Creates a new consumer instance that fronts a Protocol instance + for one of our supported protocols. Pass it the protocol, + and it will work out the type, supply appropriate callables, and + put it into the server's consumer pool. + + It returns the consumer channel name, which is how you should refer + to the consumer instance. + """ + # Make sure the protocol defines a consumer type + assert protocol.consumer_type is not None + # Make it a consumer channel name + protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10)) + consumer_channel = self.consumer_channel_prefix + protocol_id + # Make an instance of the consumer + consumer_instance = self.consumer( + type=protocol.consumer_type, + reply=lambda message: self.handle_reply(protocol, message), + channel_layer=self.channel_layer, + consumer_channel=consumer_channel, + ) + # Assign it by channel and return it + self.consumer_instances[consumer_channel] = consumer_instance + return consumer_channel + + def handle_message(self, consumer_channel, message): + """ + Schedules the application instance to handle the given message. + """ + self.pool.callInThread(self.consumer_instances[consumer_channel], message) + + def handle_reply(self, protocol, message): + """ + Schedules the reply to be handled by the protocol in the main thread + """ + reactor.callFromThread(reactor.callLater, 0, protocol.handle_reply, message) + + ### External event/message handling + + def backend_reader(self): """ Runs as an-often-as-possible task with the reactor, unless there was no result previously in which case we add a small delay. """ - channels = [self.send_channel] + channels = [self.consumer_channel_prefix] delay = 0 # Quit if reactor is stopping if not reactor.running: @@ -147,75 +181,29 @@ class Server(object): if channel: # Deal with the message try: - self.factory.dispatch_reply(channel, message) + self.handle_message(channel, message) except Exception as e: - logger.error("HTTP/WS send decode error: %s" % e) + logger.error("Error handling external message: %s" % e) else: # If there's no messages, idle a little bit. delay = 0.05 # We can't loop inside here as this is synchronous code. - reactor.callLater(delay, self.backend_reader_sync) + reactor.callLater(delay, self.backend_reader) - @defer.inlineCallbacks - def backend_reader_twisted(self): - """ - Runs as an-often-as-possible task with the reactor, unless there was - no result previously in which case we add a small delay. - """ - channels = [self.send_channel] - while True: - if not reactor.running: - logging.debug("Backend reader quitting due to reactor stop") - return - try: - channel, message = yield self.channel_layer.receive_twisted(channels) - except Exception as e: - logger.error('Error trying to receive messages: %s' % e) - yield self.sleep(5.00) - else: - # Deal with the message - if channel: - try: - self.factory.dispatch_reply(channel, message) - except Exception as e: - logger.error("HTTP/WS send decode error: %s" % e) - - def sleep(self, delay): - d = defer.Deferred() - reactor.callLater(delay, d.callback, None) - return d + ### Utility def timeout_checker(self): """ Called periodically to enforce timeout rules on all connections. Also checks pings at the same time. """ - self.factory.check_timeouts() + for protocol in self.protocols: + protocol.check_timeouts() reactor.callLater(2, self.timeout_checker) - -def build_endpoint_description_strings( - host=None, - port=None, - unix_socket=None, - file_descriptor=None - ): - """ - Build a list of twisted endpoint description strings that the server will listen on. - This is to streamline the generation of twisted endpoint description strings from easier - to use command line args such as host, port, unix sockets etc. - """ - socket_descriptions = [] - if host and port: - host = host.strip('[]').replace(':', '\:') - socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host)) - elif any([host, port]): - raise ValueError('TCP binding requires both port and host kwargs.') - - if unix_socket: - socket_descriptions.append('unix:%s' % unix_socket) - - if file_descriptor is not None: - socket_descriptions.append('fd:fileno=%d' % int(file_descriptor)) - - return socket_descriptions + def log_action(self, protocol, action, details): + """ + Dispatches to any registered action logger, if there is one. + """ + if self.action_logger: + self.action_logger(protocol, action, details) diff --git a/daphne/utils.py b/daphne/utils.py index cb8043c..9ba6291 100644 --- a/daphne/utils.py +++ b/daphne/utils.py @@ -1,6 +1,20 @@ +import sys +import importlib from twisted.web.http_headers import Headers +def import_by_path(path): + """ + Given a dotted/colon path, like project.module:ClassName.callable, + returns the object at the end of the path. + """ + module_path, object_path = path.split(":", 1) + target = importlib.import_module(module_path) + for bit in object_path.split("."): + target = getattr(target, bit) + return target + + def header_value(headers, header_name): value = headers[header_name] if isinstance(value, list): diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 502f340..a1ab19b 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -23,9 +23,9 @@ class WebSocketProtocol(WebSocketServerProtocol): # If we should send no more messages (e.g. we error-closed the socket) muted = False - def set_main_factory(self, main_factory): - self.main_factory = main_factory - self.channel_layer = self.main_factory.channel_layer + def __init__(self, *args, **kwargs): + super(WebSocketProtocol, self).__init__(*args, **kwargs) + self.server = self.factory.server_class def onConnect(self, request): self.request = request @@ -239,19 +239,29 @@ class WebSocketProtocol(WebSocketServerProtocol): """ return time.time() - self.socket_opened - def check_ping(self): + def check_timeouts(self): """ - Checks to see if we should send a keepalive ping/deny socket connection + Called periodically to see if we should timeout something """ + # Web timeout checking + if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0: + self.serverClose() + # Ping check # If we're still connecting, deny the connection if self.state == self.STATE_CONNECTING: - if self.duration() > self.main_factory.websocket_connect_timeout: + if self.duration() > self.server.websocket_connect_timeout: self.serverReject() elif self.state == self.STATE_OPEN: - if (time.time() - self.last_data) > self.main_factory.ping_interval: + if (time.time() - self.last_data) > self.server.ping_interval: self._sendAutoPing() self.last_data = time.time() + def __hash__(self): + return hash(id(self)) + + def __eq__(self, other): + return id(self) == id(other) + class WebSocketFactory(WebSocketServerFactory): """ @@ -260,9 +270,8 @@ class WebSocketFactory(WebSocketServerFactory): to get reply ID info. """ - def __init__(self, main_factory, *args, **kwargs): - self.main_factory = main_factory - WebSocketServerFactory.__init__(self, *args, **kwargs) + protocol = WebSocketProtocol - def log_action(self, *args, **kwargs): - self.main_factory.log_action(*args, **kwargs) + def __init__(self, server_class, *args, **kwargs): + self.server_class = server_class + WebSocketServerFactory.__init__(self, *args, **kwargs)