mirror of
https://github.com/django/daphne.git
synced 2024-11-21 15:36:33 +03:00
Initial refactor to get HTTP working in new style
This commit is contained in:
parent
f3b5d854ca
commit
a656c9f4c6
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -7,3 +7,5 @@ build/
|
|||
.hypothesis
|
||||
.cache
|
||||
.eggs
|
||||
test_layer*
|
||||
test_consumer*
|
||||
|
|
|
@ -2,8 +2,10 @@ import sys
|
|||
import argparse
|
||||
import logging
|
||||
import importlib
|
||||
from .server import Server, build_endpoint_description_strings
|
||||
from .server import Server
|
||||
from .endpoints import build_endpoint_description_strings
|
||||
from .access import AccessLogGenerator
|
||||
from .utils import import_by_path
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -81,7 +83,7 @@ class CommandLineInterface(object):
|
|||
'-t',
|
||||
'--http-timeout',
|
||||
type=int,
|
||||
help='How long to wait for worker server before timing out HTTP connections',
|
||||
help='How long to wait for worker before timing out HTTP connections',
|
||||
default=120,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
|
@ -123,16 +125,19 @@ class CommandLineInterface(object):
|
|||
action='store_true',
|
||||
)
|
||||
self.parser.add_argument(
|
||||
'--force-sync',
|
||||
dest='force_sync',
|
||||
action='store_true',
|
||||
help='Force the server to use synchronous mode on its ASGI channel layer',
|
||||
default=False,
|
||||
'--threads',
|
||||
help='Number of threads to run the application in',
|
||||
type=int,
|
||||
default=2,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
'channel_layer',
|
||||
help='The ASGI channel layer instance to use as path.to.module:instance.path',
|
||||
)
|
||||
self.parser.add_argument(
|
||||
'consumer',
|
||||
help='The consumer to dispatch to as path.to.module:instance.path',
|
||||
)
|
||||
|
||||
self.server = None
|
||||
|
||||
|
@ -168,13 +173,11 @@ class CommandLineInterface(object):
|
|||
access_log_stream = open(args.access_log, "a", 1)
|
||||
elif args.verbosity >= 1:
|
||||
access_log_stream = sys.stdout
|
||||
# Import channel layer
|
||||
# Import application and channel_layer
|
||||
sys.path.insert(0, ".")
|
||||
module_path, object_path = args.channel_layer.split(":", 1)
|
||||
channel_layer = importlib.import_module(module_path)
|
||||
for bit in object_path.split("."):
|
||||
channel_layer = getattr(channel_layer, bit)
|
||||
|
||||
channel_layer = import_by_path(args.channel_layer)
|
||||
consumer = import_by_path(args.consumer)
|
||||
# Set up port/host bindings
|
||||
if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
|
||||
# no advanced binding options passed, patch in defaults
|
||||
args.host = DEFAULT_HOST
|
||||
|
@ -183,8 +186,7 @@ class CommandLineInterface(object):
|
|||
args.port = DEFAULT_PORT
|
||||
elif args.port and not args.host:
|
||||
args.host = DEFAULT_HOST
|
||||
|
||||
# build endpoint description strings from (optional) cli arguments
|
||||
# Build endpoint description strings from (optional) cli arguments
|
||||
endpoints = build_endpoint_description_strings(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
|
@ -194,13 +196,14 @@ class CommandLineInterface(object):
|
|||
endpoints = sorted(
|
||||
args.socket_strings + endpoints
|
||||
)
|
||||
# Start the server
|
||||
logger.info(
|
||||
'Starting server at %s, channel layer %s.' %
|
||||
(', '.join(endpoints), args.channel_layer)
|
||||
)
|
||||
|
||||
self.server = Server(
|
||||
channel_layer=channel_layer,
|
||||
consumer=consumer,
|
||||
endpoints=endpoints,
|
||||
http_timeout=args.http_timeout,
|
||||
ping_interval=args.ping_interval,
|
||||
|
@ -213,6 +216,5 @@ class CommandLineInterface(object):
|
|||
verbosity=args.verbosity,
|
||||
proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None,
|
||||
proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None,
|
||||
force_sync=args.force_sync,
|
||||
)
|
||||
self.server.run()
|
||||
|
|
25
daphne/endpoints.py
Normal file
25
daphne/endpoints.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
def build_endpoint_description_strings(
|
||||
host=None,
|
||||
port=None,
|
||||
unix_socket=None,
|
||||
file_descriptor=None
|
||||
):
|
||||
"""
|
||||
Build a list of twisted endpoint description strings that the server will listen on.
|
||||
This is to streamline the generation of twisted endpoint description strings from easier
|
||||
to use command line args such as host, port, unix sockets etc.
|
||||
"""
|
||||
socket_descriptions = []
|
||||
if host and port:
|
||||
host = host.strip('[]').replace(':', '\:')
|
||||
socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host))
|
||||
elif any([host, port]):
|
||||
raise ValueError('TCP binding requires both port and host kwargs.')
|
||||
|
||||
if unix_socket:
|
||||
socket_descriptions.append('unix:%s' % unix_socket)
|
||||
|
||||
if file_descriptor is not None:
|
||||
socket_descriptions.append('fd:fileno=%d' % int(file_descriptor))
|
||||
|
||||
return socket_descriptions
|
|
@ -29,6 +29,8 @@ class WebRequest(http.Request):
|
|||
GET and POST out.
|
||||
"""
|
||||
|
||||
consumer_type = "http"
|
||||
|
||||
error_template = """
|
||||
<html>
|
||||
<head>
|
||||
|
@ -51,18 +53,17 @@ class WebRequest(http.Request):
|
|||
def __init__(self, *args, **kwargs):
|
||||
try:
|
||||
http.Request.__init__(self, *args, **kwargs)
|
||||
# Easy factory link
|
||||
self.factory = self.channel.factory
|
||||
# Make a name for our reply channel
|
||||
self.reply_channel = self.factory.make_send_channel()
|
||||
# Tell factory we're that channel's client
|
||||
self.last_keepalive = time.time()
|
||||
self.factory.reply_protocols[self.reply_channel] = self
|
||||
self._got_response_start = False
|
||||
# Easy server link
|
||||
self.server = self.channel.factory.server
|
||||
self.consumer_channel = None
|
||||
self._response_started = False
|
||||
self.server.add_protocol(self)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
### Twisted progress callbacks
|
||||
|
||||
def process(self):
|
||||
try:
|
||||
self.request_start = time.time()
|
||||
|
@ -79,15 +80,14 @@ class WebRequest(http.Request):
|
|||
else:
|
||||
self.client_addr = None
|
||||
self.server_addr = None
|
||||
|
||||
if self.factory.proxy_forwarded_address_header:
|
||||
# See if we need to get the address from a proxy header instead
|
||||
if self.server.proxy_forwarded_address_header:
|
||||
self.client_addr = parse_x_forwarded_for(
|
||||
self.requestHeaders,
|
||||
self.factory.proxy_forwarded_address_header,
|
||||
self.factory.proxy_forwarded_port_header,
|
||||
self.server.proxy_forwarded_address_header,
|
||||
self.server.proxy_forwarded_port_header,
|
||||
self.client_addr
|
||||
)
|
||||
|
||||
# Check for unicodeish path (or it'll crash when trying to parse)
|
||||
try:
|
||||
self.path.decode("ascii")
|
||||
|
@ -102,7 +102,7 @@ class WebRequest(http.Request):
|
|||
# Is it WebSocket? IS IT?!
|
||||
if upgrade_header and upgrade_header.lower() == b"websocket":
|
||||
# Make WebSocket protocol to hand off to
|
||||
protocol = self.factory.ws_factory.buildProtocol(self.transport.getPeer())
|
||||
protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer())
|
||||
if not protocol:
|
||||
# If protocol creation fails, we signal "internal server error"
|
||||
self.setResponseCode(500)
|
||||
|
@ -111,7 +111,6 @@ class WebRequest(http.Request):
|
|||
# Give it the raw query string
|
||||
protocol._raw_query_string = self.query_string
|
||||
# Port across transport
|
||||
protocol.set_main_factory(self.factory)
|
||||
transport, self.transport = self.transport, None
|
||||
if isinstance(transport, ProtocolWrapper):
|
||||
# i.e. TLS is a wrapping protocol
|
||||
|
@ -127,12 +126,7 @@ class WebRequest(http.Request):
|
|||
data += self.content.read()
|
||||
protocol.dataReceived(data)
|
||||
# Remove our HTTP reply channel association
|
||||
if hasattr(protocol, "reply_channel"):
|
||||
logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
|
||||
else:
|
||||
logger.debug("Connection %s did not get successful WS handshake.", self.reply_channel)
|
||||
del self.factory.reply_protocols[self.reply_channel]
|
||||
self.reply_channel = None
|
||||
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
|
||||
# Resume the producer so we keep getting data, if it's available as a method
|
||||
# 17.1 version
|
||||
if hasattr(self.channel, "_networkProducer"):
|
||||
|
@ -143,9 +137,11 @@ class WebRequest(http.Request):
|
|||
|
||||
# Boring old HTTP.
|
||||
else:
|
||||
# Create consumer to handle this connection
|
||||
self.consumer_channel = self.server.create_consumer(self)
|
||||
# Sanitize and decode headers, potentially extracting root path
|
||||
self.clean_headers = []
|
||||
self.root_path = self.factory.root_path
|
||||
self.root_path = self.server.root_path
|
||||
for name, values in self.requestHeaders.getAllRawHeaders():
|
||||
# Prevent CVE-2015-0219
|
||||
if b"_" in name:
|
||||
|
@ -155,12 +151,13 @@ class WebRequest(http.Request):
|
|||
self.root_path = self.unquote(value)
|
||||
else:
|
||||
self.clean_headers.append((name.lower(), value))
|
||||
logger.debug("HTTP %s request for %s", self.method, self.reply_channel)
|
||||
logger.debug("HTTP %s request for %s", self.method, self.consumer_channel)
|
||||
self.content.seek(0, 0)
|
||||
# Send message
|
||||
try:
|
||||
self.factory.channel_layer.send("http.request", {
|
||||
"reply_channel": self.reply_channel,
|
||||
# Run consumer
|
||||
self.server.handle_message(
|
||||
self.consumer_channel,
|
||||
{
|
||||
"type": "http.request",
|
||||
# TODO: Correctly say if it's 1.1 or 1.0
|
||||
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
|
||||
"method": self.method.decode("ascii"),
|
||||
|
@ -172,14 +169,90 @@ class WebRequest(http.Request):
|
|||
"body": self.content.read(),
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
})
|
||||
except self.factory.channel_layer.ChannelFull:
|
||||
# Channel is too full; reject request with 503
|
||||
self.basic_error(503, b"Service Unavailable", "Request queue full.")
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
self.basic_error(500, b"Internal Server Error", "HTTP processing error")
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Cleans up reply channel on close.
|
||||
"""
|
||||
if self.consumer_channel:
|
||||
self.send_disconnect()
|
||||
logger.debug("HTTP disconnect for %s", self.client_addr)
|
||||
http.Request.connectionLost(self, reason)
|
||||
self.server.discard_protocol(self)
|
||||
|
||||
def finish(self):
|
||||
"""
|
||||
Cleans up reply channel on close.
|
||||
"""
|
||||
if self.consumer_channel:
|
||||
self.send_disconnect()
|
||||
logger.debug("HTTP close for %s", self.client_addr)
|
||||
http.Request.finish(self)
|
||||
self.server.discard_protocol(self)
|
||||
|
||||
### Server reply callbacks
|
||||
|
||||
def handle_reply(self, message):
|
||||
"""
|
||||
Handles a reply from the client
|
||||
"""
|
||||
if message['type'] == "http.response":
|
||||
if self._response_started:
|
||||
raise ValueError("HTTP response has already been started")
|
||||
self._response_started = True
|
||||
if 'status' not in message:
|
||||
raise ValueError("Specifying a status code is required for a Response message.")
|
||||
# Set HTTP status code
|
||||
self.setResponseCode(message['status'])
|
||||
# Write headers
|
||||
for header, value in message.get("headers", {}):
|
||||
# Shim code from old ASGI version, can be removed after a while
|
||||
if isinstance(header, six.text_type):
|
||||
header = header.encode("latin1")
|
||||
self.responseHeaders.addRawHeader(header, value)
|
||||
logger.debug("HTTP %s response started for %s", message['status'], self.client_addr)
|
||||
elif message['type'] == "http.response.content":
|
||||
if not self._response_started:
|
||||
raise ValueError("HTTP response has not yet been started")
|
||||
else:
|
||||
raise ValueError("Cannot handle message type %s!" % message['type'])
|
||||
|
||||
# Write out body
|
||||
http.Request.write(self, message.get('content', b''))
|
||||
|
||||
# End if there's no more content
|
||||
if not message.get("more_content", False):
|
||||
self.finish()
|
||||
logger.debug("HTTP response complete for %s", self.client_addr)
|
||||
try:
|
||||
self.server.log_action("http", "complete", {
|
||||
"path": self.uri.decode("ascii"),
|
||||
"status": self.code,
|
||||
"method": self.method.decode("ascii"),
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
"time_taken": self.duration(),
|
||||
"size": self.sentLength,
|
||||
})
|
||||
except Exception as e:
|
||||
logging.error(traceback.format_exc())
|
||||
else:
|
||||
logger.debug("HTTP response chunk for %s", self.client_addr)
|
||||
|
||||
def check_timeouts(self):
|
||||
"""
|
||||
Called periodically to see if we should timeout something
|
||||
"""
|
||||
# Web timeout checking
|
||||
if self.duration() > self.server.http_timeout:
|
||||
self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.")
|
||||
|
||||
### Utility functions
|
||||
|
||||
@classmethod
|
||||
def unquote(cls, value, plus_as_space=False):
|
||||
"""
|
||||
|
@ -198,81 +271,17 @@ class WebRequest(http.Request):
|
|||
|
||||
def send_disconnect(self):
|
||||
"""
|
||||
Sends a disconnect message on the http.disconnect channel.
|
||||
Sends a http.disconnect message.
|
||||
Useful only really for long-polling.
|
||||
"""
|
||||
# If we don't yet have a path, then don't send as we never opened.
|
||||
if self.path:
|
||||
try:
|
||||
self.factory.channel_layer.send("http.disconnect", {
|
||||
"reply_channel": self.reply_channel,
|
||||
"path": self.unquote(self.path),
|
||||
})
|
||||
except self.factory.channel_layer.ChannelFull:
|
||||
pass
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Cleans up reply channel on close.
|
||||
"""
|
||||
if self.reply_channel and self.reply_channel in self.channel.factory.reply_protocols:
|
||||
self.send_disconnect()
|
||||
del self.channel.factory.reply_protocols[self.reply_channel]
|
||||
logger.debug("HTTP disconnect for %s", self.reply_channel)
|
||||
http.Request.connectionLost(self, reason)
|
||||
|
||||
def finish(self):
|
||||
"""
|
||||
Cleans up reply channel on close.
|
||||
"""
|
||||
if self.reply_channel and self.reply_channel in self.channel.factory.reply_protocols:
|
||||
self.send_disconnect()
|
||||
del self.channel.factory.reply_protocols[self.reply_channel]
|
||||
logger.debug("HTTP close for %s", self.reply_channel)
|
||||
http.Request.finish(self)
|
||||
|
||||
def serverResponse(self, message):
|
||||
"""
|
||||
Writes a received HTTP response back out to the transport.
|
||||
"""
|
||||
if not self._got_response_start:
|
||||
self._got_response_start = True
|
||||
if 'status' not in message:
|
||||
raise ValueError("Specifying a status code is required for a Response message.")
|
||||
|
||||
# Set HTTP status code
|
||||
self.setResponseCode(message['status'])
|
||||
# Write headers
|
||||
for header, value in message.get("headers", {}):
|
||||
# Shim code from old ASGI version, can be removed after a while
|
||||
if isinstance(header, six.text_type):
|
||||
header = header.encode("latin1")
|
||||
self.responseHeaders.addRawHeader(header, value)
|
||||
logger.debug("HTTP %s response started for %s", message['status'], self.reply_channel)
|
||||
else:
|
||||
if 'status' in message:
|
||||
raise ValueError("Got multiple Response messages for %s!" % self.reply_channel)
|
||||
|
||||
# Write out body
|
||||
http.Request.write(self, message.get('content', b''))
|
||||
|
||||
# End if there's no more content
|
||||
if not message.get("more_content", False):
|
||||
self.finish()
|
||||
logger.debug("HTTP response complete for %s", self.reply_channel)
|
||||
try:
|
||||
self.factory.log_action("http", "complete", {
|
||||
"path": self.uri.decode("ascii"),
|
||||
"status": self.code,
|
||||
"method": self.method.decode("ascii"),
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
"time_taken": self.duration(),
|
||||
"size": self.sentLength,
|
||||
})
|
||||
except Exception as e:
|
||||
logging.error(traceback.format_exc())
|
||||
else:
|
||||
logger.debug("HTTP response chunk for %s", self.reply_channel)
|
||||
self.server.handle_message(
|
||||
self.consumer_channel,
|
||||
{
|
||||
"type": "http.disconnect",
|
||||
},
|
||||
)
|
||||
|
||||
def duration(self):
|
||||
"""
|
||||
|
@ -298,6 +307,12 @@ class WebRequest(http.Request):
|
|||
}).encode("utf8"),
|
||||
})
|
||||
|
||||
def __hash__(self):
|
||||
return hash(id(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
return id(self) == id(other)
|
||||
|
||||
|
||||
@implementer(IProtocolNegotiationFactory)
|
||||
class HTTPFactory(http.HTTPFactory):
|
||||
|
@ -308,30 +323,9 @@ class HTTPFactory(http.HTTPFactory):
|
|||
routed appropriately.
|
||||
"""
|
||||
|
||||
def __init__(self, channel_layer, action_logger=None, send_channel=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30, proxy_forwarded_address_header=None, proxy_forwarded_port_header=None, websocket_handshake_timeout=5):
|
||||
def __init__(self, server):
|
||||
http.HTTPFactory.__init__(self)
|
||||
self.channel_layer = channel_layer
|
||||
self.action_logger = action_logger
|
||||
self.send_channel = send_channel
|
||||
assert self.send_channel is not None
|
||||
self.timeout = timeout
|
||||
self.websocket_timeout = websocket_timeout
|
||||
self.websocket_connect_timeout = websocket_connect_timeout
|
||||
self.ping_interval = ping_interval
|
||||
self.proxy_forwarded_address_header = proxy_forwarded_address_header
|
||||
self.proxy_forwarded_port_header = proxy_forwarded_port_header
|
||||
# We track all sub-protocols for response channel mapping
|
||||
self.reply_protocols = {}
|
||||
# Make a factory for WebSocket protocols
|
||||
self.ws_factory = WebSocketFactory(self, protocols=ws_protocols, server='Daphne')
|
||||
self.ws_factory.setProtocolOptions(
|
||||
autoPingTimeout=ping_timeout,
|
||||
allowNullOrigin=True,
|
||||
openHandshakeTimeout=websocket_handshake_timeout
|
||||
)
|
||||
self.ws_factory.protocol = WebSocketProtocol
|
||||
self.ws_factory.reply_protocols = self.reply_protocols
|
||||
self.root_path = root_path
|
||||
self.server = server
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
"""
|
||||
|
@ -353,9 +347,6 @@ class HTTPFactory(http.HTTPFactory):
|
|||
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
|
||||
return self.send_channel + protocol_id
|
||||
|
||||
def reply_channels(self):
|
||||
return self.reply_protocols.keys()
|
||||
|
||||
def dispatch_reply(self, channel, message):
|
||||
# If we don't know about the channel, ignore it (likely a channel we
|
||||
# used to have that's now in a group).
|
||||
|
@ -408,13 +399,6 @@ class HTTPFactory(http.HTTPFactory):
|
|||
else:
|
||||
raise ValueError("Unknown protocol class")
|
||||
|
||||
def log_action(self, protocol, action, details):
|
||||
"""
|
||||
Dispatches to any registered action logger, if there is one.
|
||||
"""
|
||||
if self.action_logger:
|
||||
self.action_logger(protocol, action, details)
|
||||
|
||||
def check_timeouts(self):
|
||||
"""
|
||||
Runs through all HTTP protocol instances and times them out if they've
|
||||
|
|
192
daphne/server.py
192
daphne/server.py
|
@ -1,5 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
|
@ -9,8 +10,10 @@ from twisted.internet import reactor, defer
|
|||
from twisted.internet.endpoints import serverFromString
|
||||
from twisted.logger import globalLogBeginner, STDLibLogObserver
|
||||
from twisted.web import http
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
|
||||
from .http_protocol import HTTPFactory
|
||||
from .ws_protocol import WebSocketFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -20,11 +23,8 @@ class Server(object):
|
|||
def __init__(
|
||||
self,
|
||||
channel_layer,
|
||||
host=None,
|
||||
port=None,
|
||||
consumer,
|
||||
endpoints=None,
|
||||
unix_socket=None,
|
||||
file_descriptor=None,
|
||||
signal_handlers=True,
|
||||
action_logger=None,
|
||||
http_timeout=120,
|
||||
|
@ -36,28 +36,14 @@ class Server(object):
|
|||
root_path="",
|
||||
proxy_forwarded_address_header=None,
|
||||
proxy_forwarded_port_header=None,
|
||||
force_sync=False,
|
||||
verbosity=1,
|
||||
websocket_handshake_timeout=5
|
||||
):
|
||||
self.channel_layer = channel_layer
|
||||
self.consumer = consumer
|
||||
self.endpoints = endpoints or []
|
||||
|
||||
if any([host, port, unix_socket, file_descriptor]):
|
||||
warnings.warn('''
|
||||
The host/port/unix_socket/file_descriptor keyword arguments to %s are deprecated.
|
||||
''' % self.__class__.__name__, DeprecationWarning)
|
||||
# build endpoint description strings from deprecated kwargs
|
||||
self.endpoints = sorted(self.endpoints + build_endpoint_description_strings(
|
||||
host=host,
|
||||
port=port,
|
||||
unix_socket=unix_socket,
|
||||
file_descriptor=file_descriptor
|
||||
))
|
||||
|
||||
if len(self.endpoints) == 0:
|
||||
raise UserWarning("No endpoints. This server will not listen on anything.")
|
||||
|
||||
self.listeners = []
|
||||
self.signal_handlers = signal_handlers
|
||||
self.action_logger = action_logger
|
||||
|
@ -73,29 +59,27 @@ class Server(object):
|
|||
self.websocket_handshake_timeout = websocket_handshake_timeout
|
||||
self.ws_protocols = ws_protocols
|
||||
self.root_path = root_path
|
||||
self.force_sync = force_sync
|
||||
self.verbosity = verbosity
|
||||
|
||||
def run(self):
|
||||
# Make the thread pool to run consumers in
|
||||
# TODO: Configurable numbers of threads
|
||||
self.pool = ThreadPool(name="consumers")
|
||||
# Make the mapping of consumer instances to consumer channels
|
||||
self.consumer_instances = {}
|
||||
# A set of current Twisted protocol instances to manage
|
||||
self.protocols = set()
|
||||
# Create process-local channel prefixes
|
||||
# TODO: Can we guarantee non-collision better?
|
||||
process_id = "".join(random.choice(string.ascii_letters) for i in range(10))
|
||||
self.send_channel = "daphne.response.%s!" % process_id
|
||||
self.consumer_channel_prefix = "daphne.%s!" % process_id
|
||||
# Make the factory
|
||||
self.factory = HTTPFactory(
|
||||
self.channel_layer,
|
||||
action_logger=self.action_logger,
|
||||
send_channel=self.send_channel,
|
||||
timeout=self.http_timeout,
|
||||
websocket_timeout=self.websocket_timeout,
|
||||
websocket_connect_timeout=self.websocket_connect_timeout,
|
||||
ping_interval=self.ping_interval,
|
||||
ping_timeout=self.ping_timeout,
|
||||
ws_protocols=self.ws_protocols,
|
||||
root_path=self.root_path,
|
||||
proxy_forwarded_address_header=self.proxy_forwarded_address_header,
|
||||
proxy_forwarded_port_header=self.proxy_forwarded_port_header,
|
||||
websocket_handshake_timeout=self.websocket_handshake_timeout
|
||||
self.http_factory = HTTPFactory(self)
|
||||
self.ws_factory = WebSocketFactory(self, protocols=self.ws_protocols, server='Daphne')
|
||||
self.ws_factory.setProtocolOptions(
|
||||
autoPingTimeout=self.ping_timeout,
|
||||
allowNullOrigin=True,
|
||||
openHandshakeTimeout=self.websocket_handshake_timeout
|
||||
)
|
||||
if self.verbosity <= 1:
|
||||
# Redirect the Twisted log to nowhere
|
||||
|
@ -109,28 +93,78 @@ class Server(object):
|
|||
else:
|
||||
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
|
||||
|
||||
if "twisted" in self.channel_layer.extensions and not self.force_sync:
|
||||
logger.info("Using native Twisted mode on channel layer")
|
||||
reactor.callLater(0, self.backend_reader_twisted)
|
||||
else:
|
||||
logger.info("Using busy-loop synchronous mode on channel layer")
|
||||
reactor.callLater(0, self.backend_reader_sync)
|
||||
# Kick off the various background loops
|
||||
reactor.callLater(0, self.backend_reader)
|
||||
reactor.callLater(2, self.timeout_checker)
|
||||
|
||||
for socket_description in self.endpoints:
|
||||
logger.info("Listening on endpoint %s" % socket_description)
|
||||
# Twisted requires str on python2 (not unicode) and str on python3 (not bytes)
|
||||
ep = serverFromString(reactor, str(socket_description))
|
||||
self.listeners.append(ep.listen(self.factory))
|
||||
self.listeners.append(ep.listen(self.http_factory))
|
||||
|
||||
self.pool.start()
|
||||
reactor.addSystemEventTrigger("before", "shutdown", self.pool.stop)
|
||||
reactor.run(installSignalHandlers=self.signal_handlers)
|
||||
|
||||
def backend_reader_sync(self):
|
||||
### Protocol handling
|
||||
|
||||
def add_protocol(self, protocol):
|
||||
if protocol in self.protocols:
|
||||
raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
|
||||
self.protocols.add(protocol)
|
||||
|
||||
def discard_protocol(self, protocol):
|
||||
self.protocols.discard(protocol)
|
||||
|
||||
### Internal event/message handling
|
||||
|
||||
def create_consumer(self, protocol):
|
||||
"""
|
||||
Creates a new consumer instance that fronts a Protocol instance
|
||||
for one of our supported protocols. Pass it the protocol,
|
||||
and it will work out the type, supply appropriate callables, and
|
||||
put it into the server's consumer pool.
|
||||
|
||||
It returns the consumer channel name, which is how you should refer
|
||||
to the consumer instance.
|
||||
"""
|
||||
# Make sure the protocol defines a consumer type
|
||||
assert protocol.consumer_type is not None
|
||||
# Make it a consumer channel name
|
||||
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
|
||||
consumer_channel = self.consumer_channel_prefix + protocol_id
|
||||
# Make an instance of the consumer
|
||||
consumer_instance = self.consumer(
|
||||
type=protocol.consumer_type,
|
||||
reply=lambda message: self.handle_reply(protocol, message),
|
||||
channel_layer=self.channel_layer,
|
||||
consumer_channel=consumer_channel,
|
||||
)
|
||||
# Assign it by channel and return it
|
||||
self.consumer_instances[consumer_channel] = consumer_instance
|
||||
return consumer_channel
|
||||
|
||||
def handle_message(self, consumer_channel, message):
|
||||
"""
|
||||
Schedules the application instance to handle the given message.
|
||||
"""
|
||||
self.pool.callInThread(self.consumer_instances[consumer_channel], message)
|
||||
|
||||
def handle_reply(self, protocol, message):
|
||||
"""
|
||||
Schedules the reply to be handled by the protocol in the main thread
|
||||
"""
|
||||
reactor.callFromThread(reactor.callLater, 0, protocol.handle_reply, message)
|
||||
|
||||
### External event/message handling
|
||||
|
||||
def backend_reader(self):
|
||||
"""
|
||||
Runs as an-often-as-possible task with the reactor, unless there was
|
||||
no result previously in which case we add a small delay.
|
||||
"""
|
||||
channels = [self.send_channel]
|
||||
channels = [self.consumer_channel_prefix]
|
||||
delay = 0
|
||||
# Quit if reactor is stopping
|
||||
if not reactor.running:
|
||||
|
@ -147,75 +181,29 @@ class Server(object):
|
|||
if channel:
|
||||
# Deal with the message
|
||||
try:
|
||||
self.factory.dispatch_reply(channel, message)
|
||||
self.handle_message(channel, message)
|
||||
except Exception as e:
|
||||
logger.error("HTTP/WS send decode error: %s" % e)
|
||||
logger.error("Error handling external message: %s" % e)
|
||||
else:
|
||||
# If there's no messages, idle a little bit.
|
||||
delay = 0.05
|
||||
# We can't loop inside here as this is synchronous code.
|
||||
reactor.callLater(delay, self.backend_reader_sync)
|
||||
reactor.callLater(delay, self.backend_reader)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def backend_reader_twisted(self):
|
||||
"""
|
||||
Runs as an-often-as-possible task with the reactor, unless there was
|
||||
no result previously in which case we add a small delay.
|
||||
"""
|
||||
channels = [self.send_channel]
|
||||
while True:
|
||||
if not reactor.running:
|
||||
logging.debug("Backend reader quitting due to reactor stop")
|
||||
return
|
||||
try:
|
||||
channel, message = yield self.channel_layer.receive_twisted(channels)
|
||||
except Exception as e:
|
||||
logger.error('Error trying to receive messages: %s' % e)
|
||||
yield self.sleep(5.00)
|
||||
else:
|
||||
# Deal with the message
|
||||
if channel:
|
||||
try:
|
||||
self.factory.dispatch_reply(channel, message)
|
||||
except Exception as e:
|
||||
logger.error("HTTP/WS send decode error: %s" % e)
|
||||
|
||||
def sleep(self, delay):
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(delay, d.callback, None)
|
||||
return d
|
||||
### Utility
|
||||
|
||||
def timeout_checker(self):
|
||||
"""
|
||||
Called periodically to enforce timeout rules on all connections.
|
||||
Also checks pings at the same time.
|
||||
"""
|
||||
self.factory.check_timeouts()
|
||||
for protocol in self.protocols:
|
||||
protocol.check_timeouts()
|
||||
reactor.callLater(2, self.timeout_checker)
|
||||
|
||||
|
||||
def build_endpoint_description_strings(
|
||||
host=None,
|
||||
port=None,
|
||||
unix_socket=None,
|
||||
file_descriptor=None
|
||||
):
|
||||
"""
|
||||
Build a list of twisted endpoint description strings that the server will listen on.
|
||||
This is to streamline the generation of twisted endpoint description strings from easier
|
||||
to use command line args such as host, port, unix sockets etc.
|
||||
"""
|
||||
socket_descriptions = []
|
||||
if host and port:
|
||||
host = host.strip('[]').replace(':', '\:')
|
||||
socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host))
|
||||
elif any([host, port]):
|
||||
raise ValueError('TCP binding requires both port and host kwargs.')
|
||||
|
||||
if unix_socket:
|
||||
socket_descriptions.append('unix:%s' % unix_socket)
|
||||
|
||||
if file_descriptor is not None:
|
||||
socket_descriptions.append('fd:fileno=%d' % int(file_descriptor))
|
||||
|
||||
return socket_descriptions
|
||||
def log_action(self, protocol, action, details):
|
||||
"""
|
||||
Dispatches to any registered action logger, if there is one.
|
||||
"""
|
||||
if self.action_logger:
|
||||
self.action_logger(protocol, action, details)
|
||||
|
|
|
@ -1,6 +1,20 @@
|
|||
import sys
|
||||
import importlib
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
|
||||
def import_by_path(path):
|
||||
"""
|
||||
Given a dotted/colon path, like project.module:ClassName.callable,
|
||||
returns the object at the end of the path.
|
||||
"""
|
||||
module_path, object_path = path.split(":", 1)
|
||||
target = importlib.import_module(module_path)
|
||||
for bit in object_path.split("."):
|
||||
target = getattr(target, bit)
|
||||
return target
|
||||
|
||||
|
||||
def header_value(headers, header_name):
|
||||
value = headers[header_name]
|
||||
if isinstance(value, list):
|
||||
|
|
|
@ -23,9 +23,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
# If we should send no more messages (e.g. we error-closed the socket)
|
||||
muted = False
|
||||
|
||||
def set_main_factory(self, main_factory):
|
||||
self.main_factory = main_factory
|
||||
self.channel_layer = self.main_factory.channel_layer
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(WebSocketProtocol, self).__init__(*args, **kwargs)
|
||||
self.server = self.factory.server_class
|
||||
|
||||
def onConnect(self, request):
|
||||
self.request = request
|
||||
|
@ -239,19 +239,29 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
"""
|
||||
return time.time() - self.socket_opened
|
||||
|
||||
def check_ping(self):
|
||||
def check_timeouts(self):
|
||||
"""
|
||||
Checks to see if we should send a keepalive ping/deny socket connection
|
||||
Called periodically to see if we should timeout something
|
||||
"""
|
||||
# Web timeout checking
|
||||
if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0:
|
||||
self.serverClose()
|
||||
# Ping check
|
||||
# If we're still connecting, deny the connection
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
if self.duration() > self.main_factory.websocket_connect_timeout:
|
||||
if self.duration() > self.server.websocket_connect_timeout:
|
||||
self.serverReject()
|
||||
elif self.state == self.STATE_OPEN:
|
||||
if (time.time() - self.last_data) > self.main_factory.ping_interval:
|
||||
if (time.time() - self.last_data) > self.server.ping_interval:
|
||||
self._sendAutoPing()
|
||||
self.last_data = time.time()
|
||||
|
||||
def __hash__(self):
|
||||
return hash(id(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
return id(self) == id(other)
|
||||
|
||||
|
||||
class WebSocketFactory(WebSocketServerFactory):
|
||||
"""
|
||||
|
@ -260,9 +270,8 @@ class WebSocketFactory(WebSocketServerFactory):
|
|||
to get reply ID info.
|
||||
"""
|
||||
|
||||
def __init__(self, main_factory, *args, **kwargs):
|
||||
self.main_factory = main_factory
|
||||
WebSocketServerFactory.__init__(self, *args, **kwargs)
|
||||
protocol = WebSocketProtocol
|
||||
|
||||
def log_action(self, *args, **kwargs):
|
||||
self.main_factory.log_action(*args, **kwargs)
|
||||
def __init__(self, server_class, *args, **kwargs):
|
||||
self.server_class = server_class
|
||||
WebSocketServerFactory.__init__(self, *args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue
Block a user