Initial refactor to get HTTP working in new style

This commit is contained in:
Andrew Godwin 2017-08-07 14:15:35 +10:00
parent f3b5d854ca
commit a656c9f4c6
7 changed files with 290 additions and 266 deletions

2
.gitignore vendored
View File

@ -7,3 +7,5 @@ build/
.hypothesis .hypothesis
.cache .cache
.eggs .eggs
test_layer*
test_consumer*

View File

@ -2,8 +2,10 @@ import sys
import argparse import argparse
import logging import logging
import importlib import importlib
from .server import Server, build_endpoint_description_strings from .server import Server
from .endpoints import build_endpoint_description_strings
from .access import AccessLogGenerator from .access import AccessLogGenerator
from .utils import import_by_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -81,7 +83,7 @@ class CommandLineInterface(object):
'-t', '-t',
'--http-timeout', '--http-timeout',
type=int, type=int,
help='How long to wait for worker server before timing out HTTP connections', help='How long to wait for worker before timing out HTTP connections',
default=120, default=120,
) )
self.parser.add_argument( self.parser.add_argument(
@ -123,16 +125,19 @@ class CommandLineInterface(object):
action='store_true', action='store_true',
) )
self.parser.add_argument( self.parser.add_argument(
'--force-sync', '--threads',
dest='force_sync', help='Number of threads to run the application in',
action='store_true', type=int,
help='Force the server to use synchronous mode on its ASGI channel layer', default=2,
default=False,
) )
self.parser.add_argument( self.parser.add_argument(
'channel_layer', 'channel_layer',
help='The ASGI channel layer instance to use as path.to.module:instance.path', help='The ASGI channel layer instance to use as path.to.module:instance.path',
) )
self.parser.add_argument(
'consumer',
help='The consumer to dispatch to as path.to.module:instance.path',
)
self.server = None self.server = None
@ -168,13 +173,11 @@ class CommandLineInterface(object):
access_log_stream = open(args.access_log, "a", 1) access_log_stream = open(args.access_log, "a", 1)
elif args.verbosity >= 1: elif args.verbosity >= 1:
access_log_stream = sys.stdout access_log_stream = sys.stdout
# Import channel layer # Import application and channel_layer
sys.path.insert(0, ".") sys.path.insert(0, ".")
module_path, object_path = args.channel_layer.split(":", 1) channel_layer = import_by_path(args.channel_layer)
channel_layer = importlib.import_module(module_path) consumer = import_by_path(args.consumer)
for bit in object_path.split("."): # Set up port/host bindings
channel_layer = getattr(channel_layer, bit)
if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]): if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
# no advanced binding options passed, patch in defaults # no advanced binding options passed, patch in defaults
args.host = DEFAULT_HOST args.host = DEFAULT_HOST
@ -183,8 +186,7 @@ class CommandLineInterface(object):
args.port = DEFAULT_PORT args.port = DEFAULT_PORT
elif args.port and not args.host: elif args.port and not args.host:
args.host = DEFAULT_HOST args.host = DEFAULT_HOST
# Build endpoint description strings from (optional) cli arguments
# build endpoint description strings from (optional) cli arguments
endpoints = build_endpoint_description_strings( endpoints = build_endpoint_description_strings(
host=args.host, host=args.host,
port=args.port, port=args.port,
@ -194,13 +196,14 @@ class CommandLineInterface(object):
endpoints = sorted( endpoints = sorted(
args.socket_strings + endpoints args.socket_strings + endpoints
) )
# Start the server
logger.info( logger.info(
'Starting server at %s, channel layer %s.' % 'Starting server at %s, channel layer %s.' %
(', '.join(endpoints), args.channel_layer) (', '.join(endpoints), args.channel_layer)
) )
self.server = Server( self.server = Server(
channel_layer=channel_layer, channel_layer=channel_layer,
consumer=consumer,
endpoints=endpoints, endpoints=endpoints,
http_timeout=args.http_timeout, http_timeout=args.http_timeout,
ping_interval=args.ping_interval, ping_interval=args.ping_interval,
@ -213,6 +216,5 @@ class CommandLineInterface(object):
verbosity=args.verbosity, verbosity=args.verbosity,
proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None, proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None,
proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None, proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None,
force_sync=args.force_sync,
) )
self.server.run() self.server.run()

25
daphne/endpoints.py Normal file
View File

@ -0,0 +1,25 @@
def build_endpoint_description_strings(
host=None,
port=None,
unix_socket=None,
file_descriptor=None
):
"""
Build a list of twisted endpoint description strings that the server will listen on.
This is to streamline the generation of twisted endpoint description strings from easier
to use command line args such as host, port, unix sockets etc.
"""
socket_descriptions = []
if host and port:
host = host.strip('[]').replace(':', '\:')
socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host))
elif any([host, port]):
raise ValueError('TCP binding requires both port and host kwargs.')
if unix_socket:
socket_descriptions.append('unix:%s' % unix_socket)
if file_descriptor is not None:
socket_descriptions.append('fd:fileno=%d' % int(file_descriptor))
return socket_descriptions

View File

@ -29,6 +29,8 @@ class WebRequest(http.Request):
GET and POST out. GET and POST out.
""" """
consumer_type = "http"
error_template = """ error_template = """
<html> <html>
<head> <head>
@ -51,18 +53,17 @@ class WebRequest(http.Request):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
try: try:
http.Request.__init__(self, *args, **kwargs) http.Request.__init__(self, *args, **kwargs)
# Easy factory link # Easy server link
self.factory = self.channel.factory self.server = self.channel.factory.server
# Make a name for our reply channel self.consumer_channel = None
self.reply_channel = self.factory.make_send_channel() self._response_started = False
# Tell factory we're that channel's client self.server.add_protocol(self)
self.last_keepalive = time.time()
self.factory.reply_protocols[self.reply_channel] = self
self._got_response_start = False
except Exception: except Exception:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise raise
### Twisted progress callbacks
def process(self): def process(self):
try: try:
self.request_start = time.time() self.request_start = time.time()
@ -79,15 +80,14 @@ class WebRequest(http.Request):
else: else:
self.client_addr = None self.client_addr = None
self.server_addr = None self.server_addr = None
# See if we need to get the address from a proxy header instead
if self.factory.proxy_forwarded_address_header: if self.server.proxy_forwarded_address_header:
self.client_addr = parse_x_forwarded_for( self.client_addr = parse_x_forwarded_for(
self.requestHeaders, self.requestHeaders,
self.factory.proxy_forwarded_address_header, self.server.proxy_forwarded_address_header,
self.factory.proxy_forwarded_port_header, self.server.proxy_forwarded_port_header,
self.client_addr self.client_addr
) )
# Check for unicodeish path (or it'll crash when trying to parse) # Check for unicodeish path (or it'll crash when trying to parse)
try: try:
self.path.decode("ascii") self.path.decode("ascii")
@ -102,7 +102,7 @@ class WebRequest(http.Request):
# Is it WebSocket? IS IT?! # Is it WebSocket? IS IT?!
if upgrade_header and upgrade_header.lower() == b"websocket": if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to # Make WebSocket protocol to hand off to
protocol = self.factory.ws_factory.buildProtocol(self.transport.getPeer()) protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer())
if not protocol: if not protocol:
# If protocol creation fails, we signal "internal server error" # If protocol creation fails, we signal "internal server error"
self.setResponseCode(500) self.setResponseCode(500)
@ -111,7 +111,6 @@ class WebRequest(http.Request):
# Give it the raw query string # Give it the raw query string
protocol._raw_query_string = self.query_string protocol._raw_query_string = self.query_string
# Port across transport # Port across transport
protocol.set_main_factory(self.factory)
transport, self.transport = self.transport, None transport, self.transport = self.transport, None
if isinstance(transport, ProtocolWrapper): if isinstance(transport, ProtocolWrapper):
# i.e. TLS is a wrapping protocol # i.e. TLS is a wrapping protocol
@ -127,12 +126,7 @@ class WebRequest(http.Request):
data += self.content.read() data += self.content.read()
protocol.dataReceived(data) protocol.dataReceived(data)
# Remove our HTTP reply channel association # Remove our HTTP reply channel association
if hasattr(protocol, "reply_channel"): logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
else:
logger.debug("Connection %s did not get successful WS handshake.", self.reply_channel)
del self.factory.reply_protocols[self.reply_channel]
self.reply_channel = None
# Resume the producer so we keep getting data, if it's available as a method # Resume the producer so we keep getting data, if it's available as a method
# 17.1 version # 17.1 version
if hasattr(self.channel, "_networkProducer"): if hasattr(self.channel, "_networkProducer"):
@ -143,9 +137,11 @@ class WebRequest(http.Request):
# Boring old HTTP. # Boring old HTTP.
else: else:
# Create consumer to handle this connection
self.consumer_channel = self.server.create_consumer(self)
# Sanitize and decode headers, potentially extracting root path # Sanitize and decode headers, potentially extracting root path
self.clean_headers = [] self.clean_headers = []
self.root_path = self.factory.root_path self.root_path = self.server.root_path
for name, values in self.requestHeaders.getAllRawHeaders(): for name, values in self.requestHeaders.getAllRawHeaders():
# Prevent CVE-2015-0219 # Prevent CVE-2015-0219
if b"_" in name: if b"_" in name:
@ -155,12 +151,13 @@ class WebRequest(http.Request):
self.root_path = self.unquote(value) self.root_path = self.unquote(value)
else: else:
self.clean_headers.append((name.lower(), value)) self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.reply_channel) logger.debug("HTTP %s request for %s", self.method, self.consumer_channel)
self.content.seek(0, 0) self.content.seek(0, 0)
# Send message # Run consumer
try: self.server.handle_message(
self.factory.channel_layer.send("http.request", { self.consumer_channel,
"reply_channel": self.reply_channel, {
"type": "http.request",
# TODO: Correctly say if it's 1.1 or 1.0 # TODO: Correctly say if it's 1.1 or 1.0
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"), "http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
"method": self.method.decode("ascii"), "method": self.method.decode("ascii"),
@ -172,14 +169,90 @@ class WebRequest(http.Request):
"body": self.content.read(), "body": self.content.read(),
"client": self.client_addr, "client": self.client_addr,
"server": self.server_addr, "server": self.server_addr,
}) },
except self.factory.channel_layer.ChannelFull: )
# Channel is too full; reject request with 503
self.basic_error(503, b"Service Unavailable", "Request queue full.")
except Exception: except Exception:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self.basic_error(500, b"Internal Server Error", "HTTP processing error") self.basic_error(500, b"Internal Server Error", "HTTP processing error")
def connectionLost(self, reason):
"""
Cleans up reply channel on close.
"""
if self.consumer_channel:
self.send_disconnect()
logger.debug("HTTP disconnect for %s", self.client_addr)
http.Request.connectionLost(self, reason)
self.server.discard_protocol(self)
def finish(self):
"""
Cleans up reply channel on close.
"""
if self.consumer_channel:
self.send_disconnect()
logger.debug("HTTP close for %s", self.client_addr)
http.Request.finish(self)
self.server.discard_protocol(self)
### Server reply callbacks
def handle_reply(self, message):
"""
Handles a reply from the client
"""
if message['type'] == "http.response":
if self._response_started:
raise ValueError("HTTP response has already been started")
self._response_started = True
if 'status' not in message:
raise ValueError("Specifying a status code is required for a Response message.")
# Set HTTP status code
self.setResponseCode(message['status'])
# Write headers
for header, value in message.get("headers", {}):
# Shim code from old ASGI version, can be removed after a while
if isinstance(header, six.text_type):
header = header.encode("latin1")
self.responseHeaders.addRawHeader(header, value)
logger.debug("HTTP %s response started for %s", message['status'], self.client_addr)
elif message['type'] == "http.response.content":
if not self._response_started:
raise ValueError("HTTP response has not yet been started")
else:
raise ValueError("Cannot handle message type %s!" % message['type'])
# Write out body
http.Request.write(self, message.get('content', b''))
# End if there's no more content
if not message.get("more_content", False):
self.finish()
logger.debug("HTTP response complete for %s", self.client_addr)
try:
self.server.log_action("http", "complete", {
"path": self.uri.decode("ascii"),
"status": self.code,
"method": self.method.decode("ascii"),
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
"time_taken": self.duration(),
"size": self.sentLength,
})
except Exception as e:
logging.error(traceback.format_exc())
else:
logger.debug("HTTP response chunk for %s", self.client_addr)
def check_timeouts(self):
"""
Called periodically to see if we should timeout something
"""
# Web timeout checking
if self.duration() > self.server.http_timeout:
self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.")
### Utility functions
@classmethod @classmethod
def unquote(cls, value, plus_as_space=False): def unquote(cls, value, plus_as_space=False):
""" """
@ -198,81 +271,17 @@ class WebRequest(http.Request):
def send_disconnect(self): def send_disconnect(self):
""" """
Sends a disconnect message on the http.disconnect channel. Sends a http.disconnect message.
Useful only really for long-polling. Useful only really for long-polling.
""" """
# If we don't yet have a path, then don't send as we never opened. # If we don't yet have a path, then don't send as we never opened.
if self.path: if self.path:
try: self.server.handle_message(
self.factory.channel_layer.send("http.disconnect", { self.consumer_channel,
"reply_channel": self.reply_channel, {
"path": self.unquote(self.path), "type": "http.disconnect",
}) },
except self.factory.channel_layer.ChannelFull: )
pass
def connectionLost(self, reason):
"""
Cleans up reply channel on close.
"""
if self.reply_channel and self.reply_channel in self.channel.factory.reply_protocols:
self.send_disconnect()
del self.channel.factory.reply_protocols[self.reply_channel]
logger.debug("HTTP disconnect for %s", self.reply_channel)
http.Request.connectionLost(self, reason)
def finish(self):
"""
Cleans up reply channel on close.
"""
if self.reply_channel and self.reply_channel in self.channel.factory.reply_protocols:
self.send_disconnect()
del self.channel.factory.reply_protocols[self.reply_channel]
logger.debug("HTTP close for %s", self.reply_channel)
http.Request.finish(self)
def serverResponse(self, message):
"""
Writes a received HTTP response back out to the transport.
"""
if not self._got_response_start:
self._got_response_start = True
if 'status' not in message:
raise ValueError("Specifying a status code is required for a Response message.")
# Set HTTP status code
self.setResponseCode(message['status'])
# Write headers
for header, value in message.get("headers", {}):
# Shim code from old ASGI version, can be removed after a while
if isinstance(header, six.text_type):
header = header.encode("latin1")
self.responseHeaders.addRawHeader(header, value)
logger.debug("HTTP %s response started for %s", message['status'], self.reply_channel)
else:
if 'status' in message:
raise ValueError("Got multiple Response messages for %s!" % self.reply_channel)
# Write out body
http.Request.write(self, message.get('content', b''))
# End if there's no more content
if not message.get("more_content", False):
self.finish()
logger.debug("HTTP response complete for %s", self.reply_channel)
try:
self.factory.log_action("http", "complete", {
"path": self.uri.decode("ascii"),
"status": self.code,
"method": self.method.decode("ascii"),
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
"time_taken": self.duration(),
"size": self.sentLength,
})
except Exception as e:
logging.error(traceback.format_exc())
else:
logger.debug("HTTP response chunk for %s", self.reply_channel)
def duration(self): def duration(self):
""" """
@ -298,6 +307,12 @@ class WebRequest(http.Request):
}).encode("utf8"), }).encode("utf8"),
}) })
def __hash__(self):
return hash(id(self))
def __eq__(self, other):
return id(self) == id(other)
@implementer(IProtocolNegotiationFactory) @implementer(IProtocolNegotiationFactory)
class HTTPFactory(http.HTTPFactory): class HTTPFactory(http.HTTPFactory):
@ -308,30 +323,9 @@ class HTTPFactory(http.HTTPFactory):
routed appropriately. routed appropriately.
""" """
def __init__(self, channel_layer, action_logger=None, send_channel=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30, proxy_forwarded_address_header=None, proxy_forwarded_port_header=None, websocket_handshake_timeout=5): def __init__(self, server):
http.HTTPFactory.__init__(self) http.HTTPFactory.__init__(self)
self.channel_layer = channel_layer self.server = server
self.action_logger = action_logger
self.send_channel = send_channel
assert self.send_channel is not None
self.timeout = timeout
self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout
self.ping_interval = ping_interval
self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header
# We track all sub-protocols for response channel mapping
self.reply_protocols = {}
# Make a factory for WebSocket protocols
self.ws_factory = WebSocketFactory(self, protocols=ws_protocols, server='Daphne')
self.ws_factory.setProtocolOptions(
autoPingTimeout=ping_timeout,
allowNullOrigin=True,
openHandshakeTimeout=websocket_handshake_timeout
)
self.ws_factory.protocol = WebSocketProtocol
self.ws_factory.reply_protocols = self.reply_protocols
self.root_path = root_path
def buildProtocol(self, addr): def buildProtocol(self, addr):
""" """
@ -353,9 +347,6 @@ class HTTPFactory(http.HTTPFactory):
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10)) protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
return self.send_channel + protocol_id return self.send_channel + protocol_id
def reply_channels(self):
return self.reply_protocols.keys()
def dispatch_reply(self, channel, message): def dispatch_reply(self, channel, message):
# If we don't know about the channel, ignore it (likely a channel we # If we don't know about the channel, ignore it (likely a channel we
# used to have that's now in a group). # used to have that's now in a group).
@ -408,13 +399,6 @@ class HTTPFactory(http.HTTPFactory):
else: else:
raise ValueError("Unknown protocol class") raise ValueError("Unknown protocol class")
def log_action(self, protocol, action, details):
"""
Dispatches to any registered action logger, if there is one.
"""
if self.action_logger:
self.action_logger(protocol, action, details)
def check_timeouts(self): def check_timeouts(self):
""" """
Runs through all HTTP protocol instances and times them out if they've Runs through all HTTP protocol instances and times them out if they've

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import collections
import logging import logging
import random import random
import string import string
@ -9,8 +10,10 @@ from twisted.internet import reactor, defer
from twisted.internet.endpoints import serverFromString from twisted.internet.endpoints import serverFromString
from twisted.logger import globalLogBeginner, STDLibLogObserver from twisted.logger import globalLogBeginner, STDLibLogObserver
from twisted.web import http from twisted.web import http
from twisted.python.threadpool import ThreadPool
from .http_protocol import HTTPFactory from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,11 +23,8 @@ class Server(object):
def __init__( def __init__(
self, self,
channel_layer, channel_layer,
host=None, consumer,
port=None,
endpoints=None, endpoints=None,
unix_socket=None,
file_descriptor=None,
signal_handlers=True, signal_handlers=True,
action_logger=None, action_logger=None,
http_timeout=120, http_timeout=120,
@ -36,28 +36,14 @@ class Server(object):
root_path="", root_path="",
proxy_forwarded_address_header=None, proxy_forwarded_address_header=None,
proxy_forwarded_port_header=None, proxy_forwarded_port_header=None,
force_sync=False,
verbosity=1, verbosity=1,
websocket_handshake_timeout=5 websocket_handshake_timeout=5
): ):
self.channel_layer = channel_layer self.channel_layer = channel_layer
self.consumer = consumer
self.endpoints = endpoints or [] self.endpoints = endpoints or []
if any([host, port, unix_socket, file_descriptor]):
warnings.warn('''
The host/port/unix_socket/file_descriptor keyword arguments to %s are deprecated.
''' % self.__class__.__name__, DeprecationWarning)
# build endpoint description strings from deprecated kwargs
self.endpoints = sorted(self.endpoints + build_endpoint_description_strings(
host=host,
port=port,
unix_socket=unix_socket,
file_descriptor=file_descriptor
))
if len(self.endpoints) == 0: if len(self.endpoints) == 0:
raise UserWarning("No endpoints. This server will not listen on anything.") raise UserWarning("No endpoints. This server will not listen on anything.")
self.listeners = [] self.listeners = []
self.signal_handlers = signal_handlers self.signal_handlers = signal_handlers
self.action_logger = action_logger self.action_logger = action_logger
@ -73,29 +59,27 @@ class Server(object):
self.websocket_handshake_timeout = websocket_handshake_timeout self.websocket_handshake_timeout = websocket_handshake_timeout
self.ws_protocols = ws_protocols self.ws_protocols = ws_protocols
self.root_path = root_path self.root_path = root_path
self.force_sync = force_sync
self.verbosity = verbosity self.verbosity = verbosity
def run(self): def run(self):
# Make the thread pool to run consumers in
# TODO: Configurable numbers of threads
self.pool = ThreadPool(name="consumers")
# Make the mapping of consumer instances to consumer channels
self.consumer_instances = {}
# A set of current Twisted protocol instances to manage
self.protocols = set()
# Create process-local channel prefixes # Create process-local channel prefixes
# TODO: Can we guarantee non-collision better? # TODO: Can we guarantee non-collision better?
process_id = "".join(random.choice(string.ascii_letters) for i in range(10)) process_id = "".join(random.choice(string.ascii_letters) for i in range(10))
self.send_channel = "daphne.response.%s!" % process_id self.consumer_channel_prefix = "daphne.%s!" % process_id
# Make the factory # Make the factory
self.factory = HTTPFactory( self.http_factory = HTTPFactory(self)
self.channel_layer, self.ws_factory = WebSocketFactory(self, protocols=self.ws_protocols, server='Daphne')
action_logger=self.action_logger, self.ws_factory.setProtocolOptions(
send_channel=self.send_channel, autoPingTimeout=self.ping_timeout,
timeout=self.http_timeout, allowNullOrigin=True,
websocket_timeout=self.websocket_timeout, openHandshakeTimeout=self.websocket_handshake_timeout
websocket_connect_timeout=self.websocket_connect_timeout,
ping_interval=self.ping_interval,
ping_timeout=self.ping_timeout,
ws_protocols=self.ws_protocols,
root_path=self.root_path,
proxy_forwarded_address_header=self.proxy_forwarded_address_header,
proxy_forwarded_port_header=self.proxy_forwarded_port_header,
websocket_handshake_timeout=self.websocket_handshake_timeout
) )
if self.verbosity <= 1: if self.verbosity <= 1:
# Redirect the Twisted log to nowhere # Redirect the Twisted log to nowhere
@ -109,28 +93,78 @@ class Server(object):
else: else:
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)") logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
if "twisted" in self.channel_layer.extensions and not self.force_sync: # Kick off the various background loops
logger.info("Using native Twisted mode on channel layer") reactor.callLater(0, self.backend_reader)
reactor.callLater(0, self.backend_reader_twisted)
else:
logger.info("Using busy-loop synchronous mode on channel layer")
reactor.callLater(0, self.backend_reader_sync)
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)
for socket_description in self.endpoints: for socket_description in self.endpoints:
logger.info("Listening on endpoint %s" % socket_description) logger.info("Listening on endpoint %s" % socket_description)
# Twisted requires str on python2 (not unicode) and str on python3 (not bytes) # Twisted requires str on python2 (not unicode) and str on python3 (not bytes)
ep = serverFromString(reactor, str(socket_description)) ep = serverFromString(reactor, str(socket_description))
self.listeners.append(ep.listen(self.factory)) self.listeners.append(ep.listen(self.http_factory))
self.pool.start()
reactor.addSystemEventTrigger("before", "shutdown", self.pool.stop)
reactor.run(installSignalHandlers=self.signal_handlers) reactor.run(installSignalHandlers=self.signal_handlers)
def backend_reader_sync(self): ### Protocol handling
def add_protocol(self, protocol):
if protocol in self.protocols:
raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
self.protocols.add(protocol)
def discard_protocol(self, protocol):
self.protocols.discard(protocol)
### Internal event/message handling
def create_consumer(self, protocol):
"""
Creates a new consumer instance that fronts a Protocol instance
for one of our supported protocols. Pass it the protocol,
and it will work out the type, supply appropriate callables, and
put it into the server's consumer pool.
It returns the consumer channel name, which is how you should refer
to the consumer instance.
"""
# Make sure the protocol defines a consumer type
assert protocol.consumer_type is not None
# Make it a consumer channel name
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
consumer_channel = self.consumer_channel_prefix + protocol_id
# Make an instance of the consumer
consumer_instance = self.consumer(
type=protocol.consumer_type,
reply=lambda message: self.handle_reply(protocol, message),
channel_layer=self.channel_layer,
consumer_channel=consumer_channel,
)
# Assign it by channel and return it
self.consumer_instances[consumer_channel] = consumer_instance
return consumer_channel
def handle_message(self, consumer_channel, message):
"""
Schedules the application instance to handle the given message.
"""
self.pool.callInThread(self.consumer_instances[consumer_channel], message)
def handle_reply(self, protocol, message):
"""
Schedules the reply to be handled by the protocol in the main thread
"""
reactor.callFromThread(reactor.callLater, 0, protocol.handle_reply, message)
### External event/message handling
def backend_reader(self):
""" """
Runs as an-often-as-possible task with the reactor, unless there was Runs as an-often-as-possible task with the reactor, unless there was
no result previously in which case we add a small delay. no result previously in which case we add a small delay.
""" """
channels = [self.send_channel] channels = [self.consumer_channel_prefix]
delay = 0 delay = 0
# Quit if reactor is stopping # Quit if reactor is stopping
if not reactor.running: if not reactor.running:
@ -147,75 +181,29 @@ class Server(object):
if channel: if channel:
# Deal with the message # Deal with the message
try: try:
self.factory.dispatch_reply(channel, message) self.handle_message(channel, message)
except Exception as e: except Exception as e:
logger.error("HTTP/WS send decode error: %s" % e) logger.error("Error handling external message: %s" % e)
else: else:
# If there's no messages, idle a little bit. # If there's no messages, idle a little bit.
delay = 0.05 delay = 0.05
# We can't loop inside here as this is synchronous code. # We can't loop inside here as this is synchronous code.
reactor.callLater(delay, self.backend_reader_sync) reactor.callLater(delay, self.backend_reader)
@defer.inlineCallbacks ### Utility
def backend_reader_twisted(self):
"""
Runs as an-often-as-possible task with the reactor, unless there was
no result previously in which case we add a small delay.
"""
channels = [self.send_channel]
while True:
if not reactor.running:
logging.debug("Backend reader quitting due to reactor stop")
return
try:
channel, message = yield self.channel_layer.receive_twisted(channels)
except Exception as e:
logger.error('Error trying to receive messages: %s' % e)
yield self.sleep(5.00)
else:
# Deal with the message
if channel:
try:
self.factory.dispatch_reply(channel, message)
except Exception as e:
logger.error("HTTP/WS send decode error: %s" % e)
def sleep(self, delay):
d = defer.Deferred()
reactor.callLater(delay, d.callback, None)
return d
def timeout_checker(self): def timeout_checker(self):
""" """
Called periodically to enforce timeout rules on all connections. Called periodically to enforce timeout rules on all connections.
Also checks pings at the same time. Also checks pings at the same time.
""" """
self.factory.check_timeouts() for protocol in self.protocols:
protocol.check_timeouts()
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)
def log_action(self, protocol, action, details):
def build_endpoint_description_strings( """
host=None, Dispatches to any registered action logger, if there is one.
port=None, """
unix_socket=None, if self.action_logger:
file_descriptor=None self.action_logger(protocol, action, details)
):
"""
Build a list of twisted endpoint description strings that the server will listen on.
This is to streamline the generation of twisted endpoint description strings from easier
to use command line args such as host, port, unix sockets etc.
"""
socket_descriptions = []
if host and port:
host = host.strip('[]').replace(':', '\:')
socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host))
elif any([host, port]):
raise ValueError('TCP binding requires both port and host kwargs.')
if unix_socket:
socket_descriptions.append('unix:%s' % unix_socket)
if file_descriptor is not None:
socket_descriptions.append('fd:fileno=%d' % int(file_descriptor))
return socket_descriptions

View File

@ -1,6 +1,20 @@
import sys
import importlib
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
def import_by_path(path):
"""
Given a dotted/colon path, like project.module:ClassName.callable,
returns the object at the end of the path.
"""
module_path, object_path = path.split(":", 1)
target = importlib.import_module(module_path)
for bit in object_path.split("."):
target = getattr(target, bit)
return target
def header_value(headers, header_name): def header_value(headers, header_name):
value = headers[header_name] value = headers[header_name]
if isinstance(value, list): if isinstance(value, list):

View File

@ -23,9 +23,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
# If we should send no more messages (e.g. we error-closed the socket) # If we should send no more messages (e.g. we error-closed the socket)
muted = False muted = False
def set_main_factory(self, main_factory): def __init__(self, *args, **kwargs):
self.main_factory = main_factory super(WebSocketProtocol, self).__init__(*args, **kwargs)
self.channel_layer = self.main_factory.channel_layer self.server = self.factory.server_class
def onConnect(self, request): def onConnect(self, request):
self.request = request self.request = request
@ -239,19 +239,29 @@ class WebSocketProtocol(WebSocketServerProtocol):
""" """
return time.time() - self.socket_opened return time.time() - self.socket_opened
def check_ping(self): def check_timeouts(self):
""" """
Checks to see if we should send a keepalive ping/deny socket connection Called periodically to see if we should timeout something
""" """
# Web timeout checking
if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0:
self.serverClose()
# Ping check
# If we're still connecting, deny the connection # If we're still connecting, deny the connection
if self.state == self.STATE_CONNECTING: if self.state == self.STATE_CONNECTING:
if self.duration() > self.main_factory.websocket_connect_timeout: if self.duration() > self.server.websocket_connect_timeout:
self.serverReject() self.serverReject()
elif self.state == self.STATE_OPEN: elif self.state == self.STATE_OPEN:
if (time.time() - self.last_data) > self.main_factory.ping_interval: if (time.time() - self.last_data) > self.server.ping_interval:
self._sendAutoPing() self._sendAutoPing()
self.last_data = time.time() self.last_data = time.time()
def __hash__(self):
return hash(id(self))
def __eq__(self, other):
return id(self) == id(other)
class WebSocketFactory(WebSocketServerFactory): class WebSocketFactory(WebSocketServerFactory):
""" """
@ -260,9 +270,8 @@ class WebSocketFactory(WebSocketServerFactory):
to get reply ID info. to get reply ID info.
""" """
def __init__(self, main_factory, *args, **kwargs): protocol = WebSocketProtocol
self.main_factory = main_factory
WebSocketServerFactory.__init__(self, *args, **kwargs)
def log_action(self, *args, **kwargs): def __init__(self, server_class, *args, **kwargs):
self.main_factory.log_action(*args, **kwargs) self.server_class = server_class
WebSocketServerFactory.__init__(self, *args, **kwargs)