Initial refactor to get HTTP working in new style

This commit is contained in:
Andrew Godwin 2017-08-07 14:15:35 +10:00
parent f3b5d854ca
commit a656c9f4c6
7 changed files with 290 additions and 266 deletions

2
.gitignore vendored
View File

@ -7,3 +7,5 @@ build/
.hypothesis
.cache
.eggs
test_layer*
test_consumer*

View File

@ -2,8 +2,10 @@ import sys
import argparse
import logging
import importlib
from .server import Server, build_endpoint_description_strings
from .server import Server
from .endpoints import build_endpoint_description_strings
from .access import AccessLogGenerator
from .utils import import_by_path
logger = logging.getLogger(__name__)
@ -81,7 +83,7 @@ class CommandLineInterface(object):
'-t',
'--http-timeout',
type=int,
help='How long to wait for worker server before timing out HTTP connections',
help='How long to wait for worker before timing out HTTP connections',
default=120,
)
self.parser.add_argument(
@ -123,16 +125,19 @@ class CommandLineInterface(object):
action='store_true',
)
self.parser.add_argument(
'--force-sync',
dest='force_sync',
action='store_true',
help='Force the server to use synchronous mode on its ASGI channel layer',
default=False,
'--threads',
help='Number of threads to run the application in',
type=int,
default=2,
)
self.parser.add_argument(
'channel_layer',
help='The ASGI channel layer instance to use as path.to.module:instance.path',
)
self.parser.add_argument(
'consumer',
help='The consumer to dispatch to as path.to.module:instance.path',
)
self.server = None
@ -168,13 +173,11 @@ class CommandLineInterface(object):
access_log_stream = open(args.access_log, "a", 1)
elif args.verbosity >= 1:
access_log_stream = sys.stdout
# Import channel layer
# Import application and channel_layer
sys.path.insert(0, ".")
module_path, object_path = args.channel_layer.split(":", 1)
channel_layer = importlib.import_module(module_path)
for bit in object_path.split("."):
channel_layer = getattr(channel_layer, bit)
channel_layer = import_by_path(args.channel_layer)
consumer = import_by_path(args.consumer)
# Set up port/host bindings
if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
# no advanced binding options passed, patch in defaults
args.host = DEFAULT_HOST
@ -183,8 +186,7 @@ class CommandLineInterface(object):
args.port = DEFAULT_PORT
elif args.port and not args.host:
args.host = DEFAULT_HOST
# build endpoint description strings from (optional) cli arguments
# Build endpoint description strings from (optional) cli arguments
endpoints = build_endpoint_description_strings(
host=args.host,
port=args.port,
@ -194,13 +196,14 @@ class CommandLineInterface(object):
endpoints = sorted(
args.socket_strings + endpoints
)
# Start the server
logger.info(
'Starting server at %s, channel layer %s.' %
(', '.join(endpoints), args.channel_layer)
)
self.server = Server(
channel_layer=channel_layer,
consumer=consumer,
endpoints=endpoints,
http_timeout=args.http_timeout,
ping_interval=args.ping_interval,
@ -213,6 +216,5 @@ class CommandLineInterface(object):
verbosity=args.verbosity,
proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None,
proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None,
force_sync=args.force_sync,
)
self.server.run()

25
daphne/endpoints.py Normal file
View File

@ -0,0 +1,25 @@
def build_endpoint_description_strings(
host=None,
port=None,
unix_socket=None,
file_descriptor=None
):
"""
Build a list of twisted endpoint description strings that the server will listen on.
This is to streamline the generation of twisted endpoint description strings from easier
to use command line args such as host, port, unix sockets etc.
"""
socket_descriptions = []
if host and port:
host = host.strip('[]').replace(':', '\:')
socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host))
elif any([host, port]):
raise ValueError('TCP binding requires both port and host kwargs.')
if unix_socket:
socket_descriptions.append('unix:%s' % unix_socket)
if file_descriptor is not None:
socket_descriptions.append('fd:fileno=%d' % int(file_descriptor))
return socket_descriptions

View File

@ -29,6 +29,8 @@ class WebRequest(http.Request):
GET and POST out.
"""
consumer_type = "http"
error_template = """
<html>
<head>
@ -51,18 +53,17 @@ class WebRequest(http.Request):
def __init__(self, *args, **kwargs):
try:
http.Request.__init__(self, *args, **kwargs)
# Easy factory link
self.factory = self.channel.factory
# Make a name for our reply channel
self.reply_channel = self.factory.make_send_channel()
# Tell factory we're that channel's client
self.last_keepalive = time.time()
self.factory.reply_protocols[self.reply_channel] = self
self._got_response_start = False
# Easy server link
self.server = self.channel.factory.server
self.consumer_channel = None
self._response_started = False
self.server.add_protocol(self)
except Exception:
logger.error(traceback.format_exc())
raise
### Twisted progress callbacks
def process(self):
try:
self.request_start = time.time()
@ -79,15 +80,14 @@ class WebRequest(http.Request):
else:
self.client_addr = None
self.server_addr = None
if self.factory.proxy_forwarded_address_header:
# See if we need to get the address from a proxy header instead
if self.server.proxy_forwarded_address_header:
self.client_addr = parse_x_forwarded_for(
self.requestHeaders,
self.factory.proxy_forwarded_address_header,
self.factory.proxy_forwarded_port_header,
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.client_addr
)
# Check for unicodeish path (or it'll crash when trying to parse)
try:
self.path.decode("ascii")
@ -102,7 +102,7 @@ class WebRequest(http.Request):
# Is it WebSocket? IS IT?!
if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to
protocol = self.factory.ws_factory.buildProtocol(self.transport.getPeer())
protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer())
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
@ -111,7 +111,6 @@ class WebRequest(http.Request):
# Give it the raw query string
protocol._raw_query_string = self.query_string
# Port across transport
protocol.set_main_factory(self.factory)
transport, self.transport = self.transport, None
if isinstance(transport, ProtocolWrapper):
# i.e. TLS is a wrapping protocol
@ -127,12 +126,7 @@ class WebRequest(http.Request):
data += self.content.read()
protocol.dataReceived(data)
# Remove our HTTP reply channel association
if hasattr(protocol, "reply_channel"):
logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
else:
logger.debug("Connection %s did not get successful WS handshake.", self.reply_channel)
del self.factory.reply_protocols[self.reply_channel]
self.reply_channel = None
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
# Resume the producer so we keep getting data, if it's available as a method
# 17.1 version
if hasattr(self.channel, "_networkProducer"):
@ -143,9 +137,11 @@ class WebRequest(http.Request):
# Boring old HTTP.
else:
# Create consumer to handle this connection
self.consumer_channel = self.server.create_consumer(self)
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
self.root_path = self.factory.root_path
self.root_path = self.server.root_path
for name, values in self.requestHeaders.getAllRawHeaders():
# Prevent CVE-2015-0219
if b"_" in name:
@ -155,12 +151,13 @@ class WebRequest(http.Request):
self.root_path = self.unquote(value)
else:
self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.reply_channel)
logger.debug("HTTP %s request for %s", self.method, self.consumer_channel)
self.content.seek(0, 0)
# Send message
try:
self.factory.channel_layer.send("http.request", {
"reply_channel": self.reply_channel,
# Run consumer
self.server.handle_message(
self.consumer_channel,
{
"type": "http.request",
# TODO: Correctly say if it's 1.1 or 1.0
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
"method": self.method.decode("ascii"),
@ -172,14 +169,90 @@ class WebRequest(http.Request):
"body": self.content.read(),
"client": self.client_addr,
"server": self.server_addr,
})
except self.factory.channel_layer.ChannelFull:
# Channel is too full; reject request with 503
self.basic_error(503, b"Service Unavailable", "Request queue full.")
},
)
except Exception:
logger.error(traceback.format_exc())
self.basic_error(500, b"Internal Server Error", "HTTP processing error")
def connectionLost(self, reason):
"""
Cleans up reply channel on close.
"""
if self.consumer_channel:
self.send_disconnect()
logger.debug("HTTP disconnect for %s", self.client_addr)
http.Request.connectionLost(self, reason)
self.server.discard_protocol(self)
def finish(self):
"""
Cleans up reply channel on close.
"""
if self.consumer_channel:
self.send_disconnect()
logger.debug("HTTP close for %s", self.client_addr)
http.Request.finish(self)
self.server.discard_protocol(self)
### Server reply callbacks
def handle_reply(self, message):
"""
Handles a reply from the client
"""
if message['type'] == "http.response":
if self._response_started:
raise ValueError("HTTP response has already been started")
self._response_started = True
if 'status' not in message:
raise ValueError("Specifying a status code is required for a Response message.")
# Set HTTP status code
self.setResponseCode(message['status'])
# Write headers
for header, value in message.get("headers", {}):
# Shim code from old ASGI version, can be removed after a while
if isinstance(header, six.text_type):
header = header.encode("latin1")
self.responseHeaders.addRawHeader(header, value)
logger.debug("HTTP %s response started for %s", message['status'], self.client_addr)
elif message['type'] == "http.response.content":
if not self._response_started:
raise ValueError("HTTP response has not yet been started")
else:
raise ValueError("Cannot handle message type %s!" % message['type'])
# Write out body
http.Request.write(self, message.get('content', b''))
# End if there's no more content
if not message.get("more_content", False):
self.finish()
logger.debug("HTTP response complete for %s", self.client_addr)
try:
self.server.log_action("http", "complete", {
"path": self.uri.decode("ascii"),
"status": self.code,
"method": self.method.decode("ascii"),
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
"time_taken": self.duration(),
"size": self.sentLength,
})
except Exception as e:
logging.error(traceback.format_exc())
else:
logger.debug("HTTP response chunk for %s", self.client_addr)
def check_timeouts(self):
"""
Called periodically to see if we should timeout something
"""
# Web timeout checking
if self.duration() > self.server.http_timeout:
self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.")
### Utility functions
@classmethod
def unquote(cls, value, plus_as_space=False):
"""
@ -198,81 +271,17 @@ class WebRequest(http.Request):
def send_disconnect(self):
"""
Sends a disconnect message on the http.disconnect channel.
Sends a http.disconnect message.
Useful only really for long-polling.
"""
# If we don't yet have a path, then don't send as we never opened.
if self.path:
try:
self.factory.channel_layer.send("http.disconnect", {
"reply_channel": self.reply_channel,
"path": self.unquote(self.path),
})
except self.factory.channel_layer.ChannelFull:
pass
def connectionLost(self, reason):
"""
Cleans up reply channel on close.
"""
if self.reply_channel and self.reply_channel in self.channel.factory.reply_protocols:
self.send_disconnect()
del self.channel.factory.reply_protocols[self.reply_channel]
logger.debug("HTTP disconnect for %s", self.reply_channel)
http.Request.connectionLost(self, reason)
def finish(self):
"""
Cleans up reply channel on close.
"""
if self.reply_channel and self.reply_channel in self.channel.factory.reply_protocols:
self.send_disconnect()
del self.channel.factory.reply_protocols[self.reply_channel]
logger.debug("HTTP close for %s", self.reply_channel)
http.Request.finish(self)
def serverResponse(self, message):
"""
Writes a received HTTP response back out to the transport.
"""
if not self._got_response_start:
self._got_response_start = True
if 'status' not in message:
raise ValueError("Specifying a status code is required for a Response message.")
# Set HTTP status code
self.setResponseCode(message['status'])
# Write headers
for header, value in message.get("headers", {}):
# Shim code from old ASGI version, can be removed after a while
if isinstance(header, six.text_type):
header = header.encode("latin1")
self.responseHeaders.addRawHeader(header, value)
logger.debug("HTTP %s response started for %s", message['status'], self.reply_channel)
else:
if 'status' in message:
raise ValueError("Got multiple Response messages for %s!" % self.reply_channel)
# Write out body
http.Request.write(self, message.get('content', b''))
# End if there's no more content
if not message.get("more_content", False):
self.finish()
logger.debug("HTTP response complete for %s", self.reply_channel)
try:
self.factory.log_action("http", "complete", {
"path": self.uri.decode("ascii"),
"status": self.code,
"method": self.method.decode("ascii"),
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
"time_taken": self.duration(),
"size": self.sentLength,
})
except Exception as e:
logging.error(traceback.format_exc())
else:
logger.debug("HTTP response chunk for %s", self.reply_channel)
self.server.handle_message(
self.consumer_channel,
{
"type": "http.disconnect",
},
)
def duration(self):
"""
@ -298,6 +307,12 @@ class WebRequest(http.Request):
}).encode("utf8"),
})
def __hash__(self):
return hash(id(self))
def __eq__(self, other):
return id(self) == id(other)
@implementer(IProtocolNegotiationFactory)
class HTTPFactory(http.HTTPFactory):
@ -308,30 +323,9 @@ class HTTPFactory(http.HTTPFactory):
routed appropriately.
"""
def __init__(self, channel_layer, action_logger=None, send_channel=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30, proxy_forwarded_address_header=None, proxy_forwarded_port_header=None, websocket_handshake_timeout=5):
def __init__(self, server):
http.HTTPFactory.__init__(self)
self.channel_layer = channel_layer
self.action_logger = action_logger
self.send_channel = send_channel
assert self.send_channel is not None
self.timeout = timeout
self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout
self.ping_interval = ping_interval
self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header
# We track all sub-protocols for response channel mapping
self.reply_protocols = {}
# Make a factory for WebSocket protocols
self.ws_factory = WebSocketFactory(self, protocols=ws_protocols, server='Daphne')
self.ws_factory.setProtocolOptions(
autoPingTimeout=ping_timeout,
allowNullOrigin=True,
openHandshakeTimeout=websocket_handshake_timeout
)
self.ws_factory.protocol = WebSocketProtocol
self.ws_factory.reply_protocols = self.reply_protocols
self.root_path = root_path
self.server = server
def buildProtocol(self, addr):
"""
@ -353,9 +347,6 @@ class HTTPFactory(http.HTTPFactory):
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
return self.send_channel + protocol_id
def reply_channels(self):
return self.reply_protocols.keys()
def dispatch_reply(self, channel, message):
# If we don't know about the channel, ignore it (likely a channel we
# used to have that's now in a group).
@ -408,13 +399,6 @@ class HTTPFactory(http.HTTPFactory):
else:
raise ValueError("Unknown protocol class")
def log_action(self, protocol, action, details):
"""
Dispatches to any registered action logger, if there is one.
"""
if self.action_logger:
self.action_logger(protocol, action, details)
def check_timeouts(self):
"""
Runs through all HTTP protocol instances and times them out if they've

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals
import collections
import logging
import random
import string
@ -9,8 +10,10 @@ from twisted.internet import reactor, defer
from twisted.internet.endpoints import serverFromString
from twisted.logger import globalLogBeginner, STDLibLogObserver
from twisted.web import http
from twisted.python.threadpool import ThreadPool
from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory
logger = logging.getLogger(__name__)
@ -20,11 +23,8 @@ class Server(object):
def __init__(
self,
channel_layer,
host=None,
port=None,
consumer,
endpoints=None,
unix_socket=None,
file_descriptor=None,
signal_handlers=True,
action_logger=None,
http_timeout=120,
@ -36,28 +36,14 @@ class Server(object):
root_path="",
proxy_forwarded_address_header=None,
proxy_forwarded_port_header=None,
force_sync=False,
verbosity=1,
websocket_handshake_timeout=5
):
self.channel_layer = channel_layer
self.consumer = consumer
self.endpoints = endpoints or []
if any([host, port, unix_socket, file_descriptor]):
warnings.warn('''
The host/port/unix_socket/file_descriptor keyword arguments to %s are deprecated.
''' % self.__class__.__name__, DeprecationWarning)
# build endpoint description strings from deprecated kwargs
self.endpoints = sorted(self.endpoints + build_endpoint_description_strings(
host=host,
port=port,
unix_socket=unix_socket,
file_descriptor=file_descriptor
))
if len(self.endpoints) == 0:
raise UserWarning("No endpoints. This server will not listen on anything.")
self.listeners = []
self.signal_handlers = signal_handlers
self.action_logger = action_logger
@ -73,29 +59,27 @@ class Server(object):
self.websocket_handshake_timeout = websocket_handshake_timeout
self.ws_protocols = ws_protocols
self.root_path = root_path
self.force_sync = force_sync
self.verbosity = verbosity
def run(self):
# Make the thread pool to run consumers in
# TODO: Configurable numbers of threads
self.pool = ThreadPool(name="consumers")
# Make the mapping of consumer instances to consumer channels
self.consumer_instances = {}
# A set of current Twisted protocol instances to manage
self.protocols = set()
# Create process-local channel prefixes
# TODO: Can we guarantee non-collision better?
process_id = "".join(random.choice(string.ascii_letters) for i in range(10))
self.send_channel = "daphne.response.%s!" % process_id
self.consumer_channel_prefix = "daphne.%s!" % process_id
# Make the factory
self.factory = HTTPFactory(
self.channel_layer,
action_logger=self.action_logger,
send_channel=self.send_channel,
timeout=self.http_timeout,
websocket_timeout=self.websocket_timeout,
websocket_connect_timeout=self.websocket_connect_timeout,
ping_interval=self.ping_interval,
ping_timeout=self.ping_timeout,
ws_protocols=self.ws_protocols,
root_path=self.root_path,
proxy_forwarded_address_header=self.proxy_forwarded_address_header,
proxy_forwarded_port_header=self.proxy_forwarded_port_header,
websocket_handshake_timeout=self.websocket_handshake_timeout
self.http_factory = HTTPFactory(self)
self.ws_factory = WebSocketFactory(self, protocols=self.ws_protocols, server='Daphne')
self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout,
allowNullOrigin=True,
openHandshakeTimeout=self.websocket_handshake_timeout
)
if self.verbosity <= 1:
# Redirect the Twisted log to nowhere
@ -109,28 +93,78 @@ class Server(object):
else:
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
if "twisted" in self.channel_layer.extensions and not self.force_sync:
logger.info("Using native Twisted mode on channel layer")
reactor.callLater(0, self.backend_reader_twisted)
else:
logger.info("Using busy-loop synchronous mode on channel layer")
reactor.callLater(0, self.backend_reader_sync)
# Kick off the various background loops
reactor.callLater(0, self.backend_reader)
reactor.callLater(2, self.timeout_checker)
for socket_description in self.endpoints:
logger.info("Listening on endpoint %s" % socket_description)
# Twisted requires str on python2 (not unicode) and str on python3 (not bytes)
ep = serverFromString(reactor, str(socket_description))
self.listeners.append(ep.listen(self.factory))
self.listeners.append(ep.listen(self.http_factory))
self.pool.start()
reactor.addSystemEventTrigger("before", "shutdown", self.pool.stop)
reactor.run(installSignalHandlers=self.signal_handlers)
def backend_reader_sync(self):
### Protocol handling
def add_protocol(self, protocol):
if protocol in self.protocols:
raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
self.protocols.add(protocol)
def discard_protocol(self, protocol):
self.protocols.discard(protocol)
### Internal event/message handling
def create_consumer(self, protocol):
"""
Creates a new consumer instance that fronts a Protocol instance
for one of our supported protocols. Pass it the protocol,
and it will work out the type, supply appropriate callables, and
put it into the server's consumer pool.
It returns the consumer channel name, which is how you should refer
to the consumer instance.
"""
# Make sure the protocol defines a consumer type
assert protocol.consumer_type is not None
# Make it a consumer channel name
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
consumer_channel = self.consumer_channel_prefix + protocol_id
# Make an instance of the consumer
consumer_instance = self.consumer(
type=protocol.consumer_type,
reply=lambda message: self.handle_reply(protocol, message),
channel_layer=self.channel_layer,
consumer_channel=consumer_channel,
)
# Assign it by channel and return it
self.consumer_instances[consumer_channel] = consumer_instance
return consumer_channel
def handle_message(self, consumer_channel, message):
"""
Schedules the application instance to handle the given message.
"""
self.pool.callInThread(self.consumer_instances[consumer_channel], message)
def handle_reply(self, protocol, message):
"""
Schedules the reply to be handled by the protocol in the main thread
"""
reactor.callFromThread(reactor.callLater, 0, protocol.handle_reply, message)
### External event/message handling
def backend_reader(self):
"""
Runs as an-often-as-possible task with the reactor, unless there was
no result previously in which case we add a small delay.
"""
channels = [self.send_channel]
channels = [self.consumer_channel_prefix]
delay = 0
# Quit if reactor is stopping
if not reactor.running:
@ -147,75 +181,29 @@ class Server(object):
if channel:
# Deal with the message
try:
self.factory.dispatch_reply(channel, message)
self.handle_message(channel, message)
except Exception as e:
logger.error("HTTP/WS send decode error: %s" % e)
logger.error("Error handling external message: %s" % e)
else:
# If there's no messages, idle a little bit.
delay = 0.05
# We can't loop inside here as this is synchronous code.
reactor.callLater(delay, self.backend_reader_sync)
reactor.callLater(delay, self.backend_reader)
@defer.inlineCallbacks
def backend_reader_twisted(self):
"""
Runs as an-often-as-possible task with the reactor, unless there was
no result previously in which case we add a small delay.
"""
channels = [self.send_channel]
while True:
if not reactor.running:
logging.debug("Backend reader quitting due to reactor stop")
return
try:
channel, message = yield self.channel_layer.receive_twisted(channels)
except Exception as e:
logger.error('Error trying to receive messages: %s' % e)
yield self.sleep(5.00)
else:
# Deal with the message
if channel:
try:
self.factory.dispatch_reply(channel, message)
except Exception as e:
logger.error("HTTP/WS send decode error: %s" % e)
def sleep(self, delay):
d = defer.Deferred()
reactor.callLater(delay, d.callback, None)
return d
### Utility
def timeout_checker(self):
"""
Called periodically to enforce timeout rules on all connections.
Also checks pings at the same time.
"""
self.factory.check_timeouts()
for protocol in self.protocols:
protocol.check_timeouts()
reactor.callLater(2, self.timeout_checker)
def build_endpoint_description_strings(
host=None,
port=None,
unix_socket=None,
file_descriptor=None
):
def log_action(self, protocol, action, details):
"""
Build a list of twisted endpoint description strings that the server will listen on.
This is to streamline the generation of twisted endpoint description strings from easier
to use command line args such as host, port, unix sockets etc.
Dispatches to any registered action logger, if there is one.
"""
socket_descriptions = []
if host and port:
host = host.strip('[]').replace(':', '\:')
socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host))
elif any([host, port]):
raise ValueError('TCP binding requires both port and host kwargs.')
if unix_socket:
socket_descriptions.append('unix:%s' % unix_socket)
if file_descriptor is not None:
socket_descriptions.append('fd:fileno=%d' % int(file_descriptor))
return socket_descriptions
if self.action_logger:
self.action_logger(protocol, action, details)

View File

@ -1,6 +1,20 @@
import sys
import importlib
from twisted.web.http_headers import Headers
def import_by_path(path):
"""
Given a dotted/colon path, like project.module:ClassName.callable,
returns the object at the end of the path.
"""
module_path, object_path = path.split(":", 1)
target = importlib.import_module(module_path)
for bit in object_path.split("."):
target = getattr(target, bit)
return target
def header_value(headers, header_name):
value = headers[header_name]
if isinstance(value, list):

View File

@ -23,9 +23,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
# If we should send no more messages (e.g. we error-closed the socket)
muted = False
def set_main_factory(self, main_factory):
self.main_factory = main_factory
self.channel_layer = self.main_factory.channel_layer
def __init__(self, *args, **kwargs):
super(WebSocketProtocol, self).__init__(*args, **kwargs)
self.server = self.factory.server_class
def onConnect(self, request):
self.request = request
@ -239,19 +239,29 @@ class WebSocketProtocol(WebSocketServerProtocol):
"""
return time.time() - self.socket_opened
def check_ping(self):
def check_timeouts(self):
"""
Checks to see if we should send a keepalive ping/deny socket connection
Called periodically to see if we should timeout something
"""
# Web timeout checking
if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0:
self.serverClose()
# Ping check
# If we're still connecting, deny the connection
if self.state == self.STATE_CONNECTING:
if self.duration() > self.main_factory.websocket_connect_timeout:
if self.duration() > self.server.websocket_connect_timeout:
self.serverReject()
elif self.state == self.STATE_OPEN:
if (time.time() - self.last_data) > self.main_factory.ping_interval:
if (time.time() - self.last_data) > self.server.ping_interval:
self._sendAutoPing()
self.last_data = time.time()
def __hash__(self):
return hash(id(self))
def __eq__(self, other):
return id(self) == id(other)
class WebSocketFactory(WebSocketServerFactory):
"""
@ -260,9 +270,8 @@ class WebSocketFactory(WebSocketServerFactory):
to get reply ID info.
"""
def __init__(self, main_factory, *args, **kwargs):
self.main_factory = main_factory
WebSocketServerFactory.__init__(self, *args, **kwargs)
protocol = WebSocketProtocol
def log_action(self, *args, **kwargs):
self.main_factory.log_action(*args, **kwargs)
def __init__(self, server_class, *args, **kwargs):
self.server_class = server_class
WebSocketServerFactory.__init__(self, *args, **kwargs)