mirror of
https://github.com/django/daphne.git
synced 2024-11-22 07:56:34 +03:00
Trying out asyncio based interface
This commit is contained in:
parent
a656c9f4c6
commit
01f174bf26
|
@ -131,12 +131,8 @@ class CommandLineInterface(object):
|
||||||
default=2,
|
default=2,
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
'channel_layer',
|
'application',
|
||||||
help='The ASGI channel layer instance to use as path.to.module:instance.path',
|
help='The application to dispatch to as path.to.module:instance.path',
|
||||||
)
|
|
||||||
self.parser.add_argument(
|
|
||||||
'consumer',
|
|
||||||
help='The consumer to dispatch to as path.to.module:instance.path',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.server = None
|
self.server = None
|
||||||
|
@ -161,6 +157,7 @@ class CommandLineInterface(object):
|
||||||
0: logging.WARN,
|
0: logging.WARN,
|
||||||
1: logging.INFO,
|
1: logging.INFO,
|
||||||
2: logging.DEBUG,
|
2: logging.DEBUG,
|
||||||
|
3: logging.DEBUG, # Also turns on asyncio debug
|
||||||
}[args.verbosity],
|
}[args.verbosity],
|
||||||
format="%(asctime)-15s %(levelname)-8s %(message)s",
|
format="%(asctime)-15s %(levelname)-8s %(message)s",
|
||||||
)
|
)
|
||||||
|
@ -173,10 +170,9 @@ class CommandLineInterface(object):
|
||||||
access_log_stream = open(args.access_log, "a", 1)
|
access_log_stream = open(args.access_log, "a", 1)
|
||||||
elif args.verbosity >= 1:
|
elif args.verbosity >= 1:
|
||||||
access_log_stream = sys.stdout
|
access_log_stream = sys.stdout
|
||||||
# Import application and channel_layer
|
# Import application
|
||||||
sys.path.insert(0, ".")
|
sys.path.insert(0, ".")
|
||||||
channel_layer = import_by_path(args.channel_layer)
|
application = import_by_path(args.application)
|
||||||
consumer = import_by_path(args.consumer)
|
|
||||||
# Set up port/host bindings
|
# Set up port/host bindings
|
||||||
if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
|
if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
|
||||||
# no advanced binding options passed, patch in defaults
|
# no advanced binding options passed, patch in defaults
|
||||||
|
@ -198,12 +194,11 @@ class CommandLineInterface(object):
|
||||||
)
|
)
|
||||||
# Start the server
|
# Start the server
|
||||||
logger.info(
|
logger.info(
|
||||||
'Starting server at %s, channel layer %s.' %
|
'Starting server at %s' %
|
||||||
(', '.join(endpoints), args.channel_layer)
|
(', '.join(endpoints), )
|
||||||
)
|
)
|
||||||
self.server = Server(
|
self.server = Server(
|
||||||
channel_layer=channel_layer,
|
application=application,
|
||||||
consumer=consumer,
|
|
||||||
endpoints=endpoints,
|
endpoints=endpoints,
|
||||||
http_timeout=args.http_timeout,
|
http_timeout=args.http_timeout,
|
||||||
ping_interval=args.ping_interval,
|
ping_interval=args.ping_interval,
|
||||||
|
|
|
@ -10,6 +10,7 @@ import traceback
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from six.moves.urllib_parse import unquote, unquote_plus
|
from six.moves.urllib_parse import unquote, unquote_plus
|
||||||
|
from twisted.internet.defer import ensureDeferred
|
||||||
from twisted.internet.interfaces import IProtocolNegotiationFactory
|
from twisted.internet.interfaces import IProtocolNegotiationFactory
|
||||||
from twisted.protocols.policies import ProtocolWrapper
|
from twisted.protocols.policies import ProtocolWrapper
|
||||||
from twisted.web import http
|
from twisted.web import http
|
||||||
|
@ -29,7 +30,7 @@ class WebRequest(http.Request):
|
||||||
GET and POST out.
|
GET and POST out.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
consumer_type = "http"
|
application_type = "http"
|
||||||
|
|
||||||
error_template = """
|
error_template = """
|
||||||
<html>
|
<html>
|
||||||
|
@ -55,7 +56,7 @@ class WebRequest(http.Request):
|
||||||
http.Request.__init__(self, *args, **kwargs)
|
http.Request.__init__(self, *args, **kwargs)
|
||||||
# Easy server link
|
# Easy server link
|
||||||
self.server = self.channel.factory.server
|
self.server = self.channel.factory.server
|
||||||
self.consumer_channel = None
|
self.application_queue = None
|
||||||
self._response_started = False
|
self._response_started = False
|
||||||
self.server.add_protocol(self)
|
self.server.add_protocol(self)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -137,8 +138,8 @@ class WebRequest(http.Request):
|
||||||
|
|
||||||
# Boring old HTTP.
|
# Boring old HTTP.
|
||||||
else:
|
else:
|
||||||
# Create consumer to handle this connection
|
# Create application to handle this connection
|
||||||
self.consumer_channel = self.server.create_consumer(self)
|
self.application_queue = self.server.create_application(self)
|
||||||
# Sanitize and decode headers, potentially extracting root path
|
# Sanitize and decode headers, potentially extracting root path
|
||||||
self.clean_headers = []
|
self.clean_headers = []
|
||||||
self.root_path = self.server.root_path
|
self.root_path = self.server.root_path
|
||||||
|
@ -151,11 +152,10 @@ class WebRequest(http.Request):
|
||||||
self.root_path = self.unquote(value)
|
self.root_path = self.unquote(value)
|
||||||
else:
|
else:
|
||||||
self.clean_headers.append((name.lower(), value))
|
self.clean_headers.append((name.lower(), value))
|
||||||
logger.debug("HTTP %s request for %s", self.method, self.consumer_channel)
|
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
|
||||||
self.content.seek(0, 0)
|
self.content.seek(0, 0)
|
||||||
# Run consumer
|
# Run application against request
|
||||||
self.server.handle_message(
|
self.application_queue.put_nowait(
|
||||||
self.consumer_channel,
|
|
||||||
{
|
{
|
||||||
"type": "http.request",
|
"type": "http.request",
|
||||||
# TODO: Correctly say if it's 1.1 or 1.0
|
# TODO: Correctly say if it's 1.1 or 1.0
|
||||||
|
@ -173,13 +173,13 @@ class WebRequest(http.Request):
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
self.basic_error(500, b"Internal Server Error", "HTTP processing error")
|
self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error")
|
||||||
|
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason):
|
||||||
"""
|
"""
|
||||||
Cleans up reply channel on close.
|
Cleans up reply channel on close.
|
||||||
"""
|
"""
|
||||||
if self.consumer_channel:
|
if self.application_queue:
|
||||||
self.send_disconnect()
|
self.send_disconnect()
|
||||||
logger.debug("HTTP disconnect for %s", self.client_addr)
|
logger.debug("HTTP disconnect for %s", self.client_addr)
|
||||||
http.Request.connectionLost(self, reason)
|
http.Request.connectionLost(self, reason)
|
||||||
|
@ -189,7 +189,7 @@ class WebRequest(http.Request):
|
||||||
"""
|
"""
|
||||||
Cleans up reply channel on close.
|
Cleans up reply channel on close.
|
||||||
"""
|
"""
|
||||||
if self.consumer_channel:
|
if self.application_queue:
|
||||||
self.send_disconnect()
|
self.send_disconnect()
|
||||||
logger.debug("HTTP close for %s", self.client_addr)
|
logger.debug("HTTP close for %s", self.client_addr)
|
||||||
http.Request.finish(self)
|
http.Request.finish(self)
|
||||||
|
@ -201,6 +201,8 @@ class WebRequest(http.Request):
|
||||||
"""
|
"""
|
||||||
Handles a reply from the client
|
Handles a reply from the client
|
||||||
"""
|
"""
|
||||||
|
if "type" not in message:
|
||||||
|
raise ValueError("Message has no type defined")
|
||||||
if message['type'] == "http.response":
|
if message['type'] == "http.response":
|
||||||
if self._response_started:
|
if self._response_started:
|
||||||
raise ValueError("HTTP response has already been started")
|
raise ValueError("HTTP response has already been started")
|
||||||
|
@ -216,9 +218,9 @@ class WebRequest(http.Request):
|
||||||
header = header.encode("latin1")
|
header = header.encode("latin1")
|
||||||
self.responseHeaders.addRawHeader(header, value)
|
self.responseHeaders.addRawHeader(header, value)
|
||||||
logger.debug("HTTP %s response started for %s", message['status'], self.client_addr)
|
logger.debug("HTTP %s response started for %s", message['status'], self.client_addr)
|
||||||
elif message['type'] == "http.response.content":
|
elif message['type'] == "http.response.hunk":
|
||||||
if not self._response_started:
|
if not self._response_started:
|
||||||
raise ValueError("HTTP response has not yet been started")
|
raise ValueError("HTTP response has not yet been started but got %s" % message['type'])
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot handle message type %s!" % message['type'])
|
raise ValueError("Cannot handle message type %s!" % message['type'])
|
||||||
|
|
||||||
|
@ -243,6 +245,12 @@ class WebRequest(http.Request):
|
||||||
else:
|
else:
|
||||||
logger.debug("HTTP response chunk for %s", self.client_addr)
|
logger.debug("HTTP response chunk for %s", self.client_addr)
|
||||||
|
|
||||||
|
def handle_exception(self, exception):
|
||||||
|
"""
|
||||||
|
Called by the server when our application tracebacks
|
||||||
|
"""
|
||||||
|
self.basic_error(500, b"Internal Server Error", "Exception inside application.")
|
||||||
|
|
||||||
def check_timeouts(self):
|
def check_timeouts(self):
|
||||||
"""
|
"""
|
||||||
Called periodically to see if we should timeout something
|
Called periodically to see if we should timeout something
|
||||||
|
@ -276,8 +284,7 @@ class WebRequest(http.Request):
|
||||||
"""
|
"""
|
||||||
# If we don't yet have a path, then don't send as we never opened.
|
# If we don't yet have a path, then don't send as we never opened.
|
||||||
if self.path:
|
if self.path:
|
||||||
self.server.handle_message(
|
self.application_queue.put_nowait(
|
||||||
self.consumer_channel,
|
|
||||||
{
|
{
|
||||||
"type": "http.disconnect",
|
"type": "http.disconnect",
|
||||||
},
|
},
|
||||||
|
@ -295,7 +302,8 @@ class WebRequest(http.Request):
|
||||||
"""
|
"""
|
||||||
Responds with a server-level error page (very basic)
|
Responds with a server-level error page (very basic)
|
||||||
"""
|
"""
|
||||||
self.serverResponse({
|
self.handle_reply({
|
||||||
|
"type": "http.response",
|
||||||
"status": status,
|
"status": status,
|
||||||
"status_text": status_text,
|
"status_text": status_text,
|
||||||
"headers": [
|
"headers": [
|
||||||
|
@ -340,83 +348,6 @@ class HTTPFactory(http.HTTPFactory):
|
||||||
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def make_send_channel(self):
|
|
||||||
"""
|
|
||||||
Makes a new send channel for a protocol with our process prefix.
|
|
||||||
"""
|
|
||||||
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
|
|
||||||
return self.send_channel + protocol_id
|
|
||||||
|
|
||||||
def dispatch_reply(self, channel, message):
|
|
||||||
# If we don't know about the channel, ignore it (likely a channel we
|
|
||||||
# used to have that's now in a group).
|
|
||||||
# TODO: Find a better way of alerting people when this happens so
|
|
||||||
# they can do more cleanup, that's not an error.
|
|
||||||
if channel not in self.reply_protocols:
|
|
||||||
logger.debug("Message on unknown channel %r - discarding" % channel)
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(self.reply_protocols[channel], WebRequest):
|
|
||||||
self.reply_protocols[channel].serverResponse(message)
|
|
||||||
elif isinstance(self.reply_protocols[channel], WebSocketProtocol):
|
|
||||||
# Switch depending on current socket state
|
|
||||||
protocol = self.reply_protocols[channel]
|
|
||||||
# See if the message is valid
|
|
||||||
unknown_keys = set(message.keys()) - {"bytes", "text", "close", "accept"}
|
|
||||||
if unknown_keys:
|
|
||||||
raise ValueError(
|
|
||||||
"Got invalid WebSocket reply message on %s - "
|
|
||||||
"contains unknown keys %s (looking for either {'accept', 'text', 'bytes', 'close'})" % (
|
|
||||||
channel,
|
|
||||||
unknown_keys,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# Accepts allow bytes/text afterwards
|
|
||||||
if message.get("accept", None) and protocol.state == protocol.STATE_CONNECTING:
|
|
||||||
protocol.serverAccept()
|
|
||||||
# Rejections must be the only thing
|
|
||||||
if message.get("accept", None) == False and protocol.state == protocol.STATE_CONNECTING:
|
|
||||||
protocol.serverReject()
|
|
||||||
return
|
|
||||||
# You're only allowed one of bytes or text
|
|
||||||
if message.get("bytes", None) and message.get("text", None):
|
|
||||||
raise ValueError(
|
|
||||||
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
|
|
||||||
channel,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if message.get("bytes", None):
|
|
||||||
protocol.serverSend(message["bytes"], True)
|
|
||||||
if message.get("text", None):
|
|
||||||
protocol.serverSend(message["text"], False)
|
|
||||||
|
|
||||||
closing_code = message.get("close", False)
|
|
||||||
if closing_code:
|
|
||||||
if protocol.state == protocol.STATE_CONNECTING:
|
|
||||||
protocol.serverReject()
|
|
||||||
else:
|
|
||||||
protocol.serverClose(code=closing_code)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unknown protocol class")
|
|
||||||
|
|
||||||
def check_timeouts(self):
|
|
||||||
"""
|
|
||||||
Runs through all HTTP protocol instances and times them out if they've
|
|
||||||
taken too long (and so their message is probably expired)
|
|
||||||
"""
|
|
||||||
for protocol in list(self.reply_protocols.values()):
|
|
||||||
# Web timeout checking
|
|
||||||
if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout:
|
|
||||||
protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.")
|
|
||||||
# WebSocket timeout checking and keepalive ping sending
|
|
||||||
elif isinstance(protocol, WebSocketProtocol):
|
|
||||||
# Timeout check
|
|
||||||
if protocol.duration() > self.websocket_timeout and self.websocket_timeout >= 0:
|
|
||||||
protocol.serverClose()
|
|
||||||
# Ping check
|
|
||||||
else:
|
|
||||||
protocol.check_ping()
|
|
||||||
|
|
||||||
# IProtocolNegotiationFactory
|
# IProtocolNegotiationFactory
|
||||||
def acceptableProtocols(self):
|
def acceptableProtocols(self):
|
||||||
"""
|
"""
|
||||||
|
|
178
daphne/server.py
178
daphne/server.py
|
@ -1,16 +1,21 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
# This has to be done first as Twisted is import-order-sensitive with reactors
|
||||||
|
from twisted.internet import asyncioreactor
|
||||||
|
asyncioreactor.install()
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from twisted.internet import reactor, defer
|
from twisted.internet import reactor, defer
|
||||||
from twisted.internet.endpoints import serverFromString
|
from twisted.internet.endpoints import serverFromString
|
||||||
from twisted.logger import globalLogBeginner, STDLibLogObserver
|
from twisted.logger import globalLogBeginner, STDLibLogObserver
|
||||||
from twisted.web import http
|
from twisted.web import http
|
||||||
from twisted.python.threadpool import ThreadPool
|
|
||||||
|
|
||||||
from .http_protocol import HTTPFactory
|
from .http_protocol import HTTPFactory
|
||||||
from .ws_protocol import WebSocketFactory
|
from .ws_protocol import WebSocketFactory
|
||||||
|
@ -22,13 +27,12 @@ class Server(object):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channel_layer,
|
application,
|
||||||
consumer,
|
|
||||||
endpoints=None,
|
endpoints=None,
|
||||||
signal_handlers=True,
|
signal_handlers=True,
|
||||||
action_logger=None,
|
action_logger=None,
|
||||||
http_timeout=120,
|
http_timeout=120,
|
||||||
websocket_timeout=None,
|
websocket_timeout=86400,
|
||||||
websocket_connect_timeout=20,
|
websocket_connect_timeout=20,
|
||||||
ping_interval=20,
|
ping_interval=20,
|
||||||
ping_timeout=30,
|
ping_timeout=30,
|
||||||
|
@ -39,10 +43,9 @@ class Server(object):
|
||||||
verbosity=1,
|
verbosity=1,
|
||||||
websocket_handshake_timeout=5
|
websocket_handshake_timeout=5
|
||||||
):
|
):
|
||||||
self.channel_layer = channel_layer
|
self.application = application
|
||||||
self.consumer = consumer
|
|
||||||
self.endpoints = endpoints or []
|
self.endpoints = endpoints or []
|
||||||
if len(self.endpoints) == 0:
|
if not self.endpoints:
|
||||||
raise UserWarning("No endpoints. This server will not listen on anything.")
|
raise UserWarning("No endpoints. This server will not listen on anything.")
|
||||||
self.listeners = []
|
self.listeners = []
|
||||||
self.signal_handlers = signal_handlers
|
self.signal_handlers = signal_handlers
|
||||||
|
@ -52,30 +55,20 @@ class Server(object):
|
||||||
self.ping_timeout = ping_timeout
|
self.ping_timeout = ping_timeout
|
||||||
self.proxy_forwarded_address_header = proxy_forwarded_address_header
|
self.proxy_forwarded_address_header = proxy_forwarded_address_header
|
||||||
self.proxy_forwarded_port_header = proxy_forwarded_port_header
|
self.proxy_forwarded_port_header = proxy_forwarded_port_header
|
||||||
# If they did not provide a websocket timeout, default it to the
|
self.websocket_timeout = websocket_timeout
|
||||||
# channel layer's group_expiry value if present, or one day if not.
|
|
||||||
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
|
|
||||||
self.websocket_connect_timeout = websocket_connect_timeout
|
self.websocket_connect_timeout = websocket_connect_timeout
|
||||||
self.websocket_handshake_timeout = websocket_handshake_timeout
|
self.websocket_handshake_timeout = websocket_handshake_timeout
|
||||||
self.ws_protocols = ws_protocols
|
self.websocket_protocols = ws_protocols
|
||||||
self.root_path = root_path
|
self.root_path = root_path
|
||||||
self.verbosity = verbosity
|
self.verbosity = verbosity
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
# Make the thread pool to run consumers in
|
|
||||||
# TODO: Configurable numbers of threads
|
|
||||||
self.pool = ThreadPool(name="consumers")
|
|
||||||
# Make the mapping of consumer instances to consumer channels
|
|
||||||
self.consumer_instances = {}
|
|
||||||
# A set of current Twisted protocol instances to manage
|
# A set of current Twisted protocol instances to manage
|
||||||
self.protocols = set()
|
self.protocols = set()
|
||||||
# Create process-local channel prefixes
|
self.application_instances = {}
|
||||||
# TODO: Can we guarantee non-collision better?
|
|
||||||
process_id = "".join(random.choice(string.ascii_letters) for i in range(10))
|
|
||||||
self.consumer_channel_prefix = "daphne.%s!" % process_id
|
|
||||||
# Make the factory
|
# Make the factory
|
||||||
self.http_factory = HTTPFactory(self)
|
self.http_factory = HTTPFactory(self)
|
||||||
self.ws_factory = WebSocketFactory(self, protocols=self.ws_protocols, server='Daphne')
|
self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server='Daphne')
|
||||||
self.ws_factory.setProtocolOptions(
|
self.ws_factory.setProtocolOptions(
|
||||||
autoPingTimeout=self.ping_timeout,
|
autoPingTimeout=self.ping_timeout,
|
||||||
allowNullOrigin=True,
|
allowNullOrigin=True,
|
||||||
|
@ -93,18 +86,24 @@ class Server(object):
|
||||||
else:
|
else:
|
||||||
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
|
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
|
||||||
|
|
||||||
# Kick off the various background loops
|
# Kick off the timeout loop
|
||||||
reactor.callLater(0, self.backend_reader)
|
reactor.callLater(1, self.application_checker)
|
||||||
reactor.callLater(2, self.timeout_checker)
|
reactor.callLater(2, self.timeout_checker)
|
||||||
|
|
||||||
for socket_description in self.endpoints:
|
for socket_description in self.endpoints:
|
||||||
logger.info("Listening on endpoint %s" % socket_description)
|
logger.info("Listening on endpoint %s" % socket_description)
|
||||||
# Twisted requires str on python2 (not unicode) and str on python3 (not bytes)
|
|
||||||
ep = serverFromString(reactor, str(socket_description))
|
ep = serverFromString(reactor, str(socket_description))
|
||||||
self.listeners.append(ep.listen(self.http_factory))
|
self.listeners.append(ep.listen(self.http_factory))
|
||||||
|
|
||||||
self.pool.start()
|
# Set the asyncio reactor's event loop as global
|
||||||
reactor.addSystemEventTrigger("before", "shutdown", self.pool.stop)
|
# TODO: Should we instead pass the global one into the reactor?
|
||||||
|
asyncio.set_event_loop(reactor._asyncioEventloop)
|
||||||
|
|
||||||
|
# Verbosity 3 turns on asyncio debug to find those blocking yields
|
||||||
|
if self.verbosity >= 3:
|
||||||
|
asyncio.get_event_loop().set_debug(True)
|
||||||
|
|
||||||
|
reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
|
||||||
reactor.run(installSignalHandlers=self.signal_handlers)
|
reactor.run(installSignalHandlers=self.signal_handlers)
|
||||||
|
|
||||||
### Protocol handling
|
### Protocol handling
|
||||||
|
@ -115,83 +114,86 @@ class Server(object):
|
||||||
self.protocols.add(protocol)
|
self.protocols.add(protocol)
|
||||||
|
|
||||||
def discard_protocol(self, protocol):
|
def discard_protocol(self, protocol):
|
||||||
|
# Ensure it's not in the protocol-tracking set
|
||||||
self.protocols.discard(protocol)
|
self.protocols.discard(protocol)
|
||||||
|
# Make sure any application future that's running is cancelled
|
||||||
|
if protocol in self.application_instances:
|
||||||
|
self.application_instances[protocol].cancel()
|
||||||
|
del self.application_instances[protocol]
|
||||||
|
|
||||||
### Internal event/message handling
|
### Internal event/message handling
|
||||||
|
|
||||||
def create_consumer(self, protocol):
|
def create_application(self, protocol):
|
||||||
"""
|
"""
|
||||||
Creates a new consumer instance that fronts a Protocol instance
|
Creates a new application instance that fronts a Protocol instance
|
||||||
for one of our supported protocols. Pass it the protocol,
|
for one of our supported protocols. Pass it the protocol,
|
||||||
and it will work out the type, supply appropriate callables, and
|
and it will work out the type, supply appropriate callables, and
|
||||||
put it into the server's consumer pool.
|
return you the application's input queue
|
||||||
|
|
||||||
It returns the consumer channel name, which is how you should refer
|
|
||||||
to the consumer instance.
|
|
||||||
"""
|
"""
|
||||||
# Make sure the protocol defines a consumer type
|
# Make sure the protocol defines a application type
|
||||||
assert protocol.consumer_type is not None
|
assert protocol.application_type is not None
|
||||||
# Make it a consumer channel name
|
# Make sure the protocol has not had another application made for it
|
||||||
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
|
assert protocol not in self.application_instances
|
||||||
consumer_channel = self.consumer_channel_prefix + protocol_id
|
# Make an instance of the application
|
||||||
# Make an instance of the consumer
|
input_queue = asyncio.Queue()
|
||||||
consumer_instance = self.consumer(
|
application_instance = asyncio.ensure_future(self.application(
|
||||||
type=protocol.consumer_type,
|
type=protocol.application_type,
|
||||||
|
next=input_queue.get,
|
||||||
reply=lambda message: self.handle_reply(protocol, message),
|
reply=lambda message: self.handle_reply(protocol, message),
|
||||||
channel_layer=self.channel_layer,
|
), loop=asyncio.get_event_loop())
|
||||||
consumer_channel=consumer_channel,
|
self.application_instances[protocol] = application_instance
|
||||||
)
|
return input_queue
|
||||||
# Assign it by channel and return it
|
|
||||||
self.consumer_instances[consumer_channel] = consumer_instance
|
|
||||||
return consumer_channel
|
|
||||||
|
|
||||||
def handle_message(self, consumer_channel, message):
|
async def handle_reply(self, protocol, message):
|
||||||
"""
|
"""
|
||||||
Schedules the application instance to handle the given message.
|
Coroutine that jumps the reply message from asyncio to Twisted
|
||||||
"""
|
"""
|
||||||
self.pool.callInThread(self.consumer_instances[consumer_channel], message)
|
reactor.callLater(0, protocol.handle_reply, message)
|
||||||
|
|
||||||
def handle_reply(self, protocol, message):
|
|
||||||
"""
|
|
||||||
Schedules the reply to be handled by the protocol in the main thread
|
|
||||||
"""
|
|
||||||
reactor.callFromThread(reactor.callLater, 0, protocol.handle_reply, message)
|
|
||||||
|
|
||||||
### External event/message handling
|
|
||||||
|
|
||||||
def backend_reader(self):
|
|
||||||
"""
|
|
||||||
Runs as an-often-as-possible task with the reactor, unless there was
|
|
||||||
no result previously in which case we add a small delay.
|
|
||||||
"""
|
|
||||||
channels = [self.consumer_channel_prefix]
|
|
||||||
delay = 0
|
|
||||||
# Quit if reactor is stopping
|
|
||||||
if not reactor.running:
|
|
||||||
logger.debug("Backend reader quitting due to reactor stop")
|
|
||||||
return
|
|
||||||
# Try to receive a message
|
|
||||||
try:
|
|
||||||
channel, message = self.channel_layer.receive(channels, block=False)
|
|
||||||
except Exception as e:
|
|
||||||
# Log the error and wait a bit to retry
|
|
||||||
logger.error('Error trying to receive messages: %s' % e)
|
|
||||||
delay = 5.00
|
|
||||||
else:
|
|
||||||
if channel:
|
|
||||||
# Deal with the message
|
|
||||||
try:
|
|
||||||
self.handle_message(channel, message)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error handling external message: %s" % e)
|
|
||||||
else:
|
|
||||||
# If there's no messages, idle a little bit.
|
|
||||||
delay = 0.05
|
|
||||||
# We can't loop inside here as this is synchronous code.
|
|
||||||
reactor.callLater(delay, self.backend_reader)
|
|
||||||
|
|
||||||
### Utility
|
### Utility
|
||||||
|
|
||||||
|
def application_checker(self):
|
||||||
|
"""
|
||||||
|
Goes through the set of current application Futures and cleans up
|
||||||
|
any that are done/prints exceptions for any that errored.
|
||||||
|
"""
|
||||||
|
for protocol, application_instance in list(self.application_instances.items()):
|
||||||
|
if application_instance.done():
|
||||||
|
exception = application_instance.exception()
|
||||||
|
if exception:
|
||||||
|
logging.error(
|
||||||
|
"Exception inside application: {}\n{}{}".format(
|
||||||
|
exception,
|
||||||
|
"".join(traceback.format_tb(
|
||||||
|
exception.__traceback__,
|
||||||
|
)),
|
||||||
|
" {}".format(exception),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
protocol.handle_exception(exception)
|
||||||
|
try:
|
||||||
|
del self.application_instances[protocol]
|
||||||
|
except KeyError:
|
||||||
|
# The protocol might have already got here before us. That's fine.
|
||||||
|
pass
|
||||||
|
reactor.callLater(1, self.application_checker)
|
||||||
|
|
||||||
|
def kill_all_applications(self):
|
||||||
|
"""
|
||||||
|
Kills all application coroutines before reactor exit.
|
||||||
|
"""
|
||||||
|
# Send cancel to all coroutines
|
||||||
|
wait_for = []
|
||||||
|
for application_instance in self.application_instances.values():
|
||||||
|
if not application_instance.done():
|
||||||
|
application_instance.cancel()
|
||||||
|
wait_for.append(application_instance)
|
||||||
|
logging.info("Killed %i pending application instances" % len(wait_for))
|
||||||
|
# Make Twisted wait until they're all dead
|
||||||
|
wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for))
|
||||||
|
wait_deferred.addErrback(lambda x: None)
|
||||||
|
return wait_deferred
|
||||||
|
|
||||||
def timeout_checker(self):
|
def timeout_checker(self):
|
||||||
"""
|
"""
|
||||||
Called periodically to enforce timeout rules on all connections.
|
Called periodically to enforce timeout rules on all connections.
|
||||||
|
|
|
@ -20,20 +20,21 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
the websocket channels.
|
the websocket channels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
application_type = "websocket"
|
||||||
|
|
||||||
# If we should send no more messages (e.g. we error-closed the socket)
|
# If we should send no more messages (e.g. we error-closed the socket)
|
||||||
muted = False
|
muted = False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(WebSocketProtocol, self).__init__(*args, **kwargs)
|
|
||||||
self.server = self.factory.server_class
|
|
||||||
|
|
||||||
def onConnect(self, request):
|
def onConnect(self, request):
|
||||||
|
self.server = self.factory.server_class
|
||||||
|
self.server.add_protocol(self)
|
||||||
self.request = request
|
self.request = request
|
||||||
self.packets_received = 0
|
|
||||||
self.protocol_to_accept = None
|
self.protocol_to_accept = None
|
||||||
self.socket_opened = time.time()
|
self.socket_opened = time.time()
|
||||||
self.last_data = time.time()
|
self.last_data = time.time()
|
||||||
try:
|
try:
|
||||||
|
# Make new application instance
|
||||||
|
self.application_queue = self.server.create_application(self)
|
||||||
# Sanitize and decode headers
|
# Sanitize and decode headers
|
||||||
self.clean_headers = []
|
self.clean_headers = []
|
||||||
for name, value in request.headers.items():
|
for name, value in request.headers.items():
|
||||||
|
@ -42,10 +43,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
if b"_" in name:
|
if b"_" in name:
|
||||||
continue
|
continue
|
||||||
self.clean_headers.append((name.lower(), value.encode("latin1")))
|
self.clean_headers.append((name.lower(), value.encode("latin1")))
|
||||||
# Make sending channel
|
|
||||||
self.reply_channel = self.main_factory.make_send_channel()
|
|
||||||
# Tell main factory about it
|
|
||||||
self.main_factory.reply_protocols[self.reply_channel] = self
|
|
||||||
# Get client address if possible
|
# Get client address if possible
|
||||||
peer = self.transport.getPeer()
|
peer = self.transport.getPeer()
|
||||||
host = self.transport.getHost()
|
host = self.transport.getHost()
|
||||||
|
@ -56,23 +53,23 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
self.client_addr = None
|
self.client_addr = None
|
||||||
self.server_addr = None
|
self.server_addr = None
|
||||||
|
|
||||||
if self.main_factory.proxy_forwarded_address_header:
|
if self.server.proxy_forwarded_address_header:
|
||||||
self.client_addr = parse_x_forwarded_for(
|
self.client_addr = parse_x_forwarded_for(
|
||||||
self.http_headers,
|
self.http_headers,
|
||||||
self.main_factory.proxy_forwarded_address_header,
|
self.server.proxy_forwarded_address_header,
|
||||||
self.main_factory.proxy_forwarded_port_header,
|
self.server.proxy_forwarded_port_header,
|
||||||
self.client_addr
|
self.client_addr
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make initial request info dict from request (we only have it here)
|
# Make initial request info dict from request (we only have it here)
|
||||||
self.path = request.path.encode("ascii")
|
self.path = request.path.encode("ascii")
|
||||||
self.request_info = {
|
self.connect_message = {
|
||||||
|
"type": "websocket.connect",
|
||||||
"path": self.unquote(self.path),
|
"path": self.unquote(self.path),
|
||||||
"headers": self.clean_headers,
|
"headers": self.clean_headers,
|
||||||
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
||||||
"client": self.client_addr,
|
"client": self.client_addr,
|
||||||
"server": self.server_addr,
|
"server": self.server_addr,
|
||||||
"reply_channel": self.reply_channel,
|
|
||||||
"order": 0,
|
"order": 0,
|
||||||
}
|
}
|
||||||
except:
|
except:
|
||||||
|
@ -86,50 +83,33 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
if header == b'sec-websocket-protocol':
|
if header == b'sec-websocket-protocol':
|
||||||
protocols = [x.strip() for x in self.unquote(value).split(",")]
|
protocols = [x.strip() for x in self.unquote(value).split(",")]
|
||||||
for protocol in protocols:
|
for protocol in protocols:
|
||||||
if protocol in self.factory.protocols:
|
if protocol in self.server.websocket_protocols:
|
||||||
ws_protocol = protocol
|
ws_protocol = protocol
|
||||||
break
|
break
|
||||||
|
|
||||||
# Work out what subprotocol we will accept, if any
|
# Work out what subprotocol we will accept, if any
|
||||||
if ws_protocol and ws_protocol in self.factory.protocols:
|
if ws_protocol and ws_protocol in self.server.websocket_protocols:
|
||||||
self.protocol_to_accept = ws_protocol
|
self.protocol_to_accept = ws_protocol
|
||||||
else:
|
else:
|
||||||
self.protocol_to_accept = None
|
self.protocol_to_accept = None
|
||||||
|
|
||||||
# Send over the connect message
|
# Send over the connect message
|
||||||
try:
|
self.application_queue.put_nowait(self.connect_message)
|
||||||
self.channel_layer.send("websocket.connect", self.request_info)
|
self.server.log_action("websocket", "connecting", {
|
||||||
except self.channel_layer.ChannelFull:
|
"path": self.request.path,
|
||||||
# You have to consume websocket.connect according to the spec,
|
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||||
# so drop the connection.
|
})
|
||||||
self.muted = True
|
|
||||||
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
|
|
||||||
# Send code 503 "Service Unavailable" with close.
|
|
||||||
raise ConnectionDeny(code=503, reason="Connection queue at capacity")
|
|
||||||
else:
|
|
||||||
self.factory.log_action("websocket", "connecting", {
|
|
||||||
"path": self.request.path,
|
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
|
||||||
})
|
|
||||||
|
|
||||||
# Make a deferred and return it - we'll either call it or err it later on
|
# Make a deferred and return it - we'll either call it or err it later on
|
||||||
self.handshake_deferred = defer.Deferred()
|
self.handshake_deferred = defer.Deferred()
|
||||||
return self.handshake_deferred
|
return self.handshake_deferred
|
||||||
|
|
||||||
@classmethod
|
### Twisted event handling
|
||||||
def unquote(cls, value):
|
|
||||||
"""
|
|
||||||
Python 2 and 3 compat layer for utf-8 unquoting
|
|
||||||
"""
|
|
||||||
if six.PY2:
|
|
||||||
return unquote(value).decode("utf8")
|
|
||||||
else:
|
|
||||||
return unquote(value.decode("ascii"))
|
|
||||||
|
|
||||||
def onOpen(self):
|
def onOpen(self):
|
||||||
# Send news that this channel is open
|
# Send news that this channel is open
|
||||||
logger.debug("WebSocket %s open and established", self.reply_channel)
|
logger.debug("WebSocket %s open and established", self.client_addr)
|
||||||
self.factory.log_action("websocket", "connected", {
|
self.server.log_action("websocket", "connected", {
|
||||||
"path": self.request.path,
|
"path": self.request.path,
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||||
})
|
})
|
||||||
|
@ -137,49 +117,78 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
def onMessage(self, payload, isBinary):
|
def onMessage(self, payload, isBinary):
|
||||||
# If we're muted, do nothing.
|
# If we're muted, do nothing.
|
||||||
if self.muted:
|
if self.muted:
|
||||||
logger.debug("Muting incoming frame on %s", self.reply_channel)
|
logger.debug("Muting incoming frame on %s", self.client_addr)
|
||||||
return
|
return
|
||||||
logger.debug("WebSocket incoming frame on %s", self.reply_channel)
|
logger.debug("WebSocket incoming frame on %s", self.client_addr)
|
||||||
self.packets_received += 1
|
|
||||||
self.last_data = time.time()
|
self.last_data = time.time()
|
||||||
try:
|
if isBinary:
|
||||||
if isBinary:
|
self.application_queue.put_nowait({
|
||||||
self.channel_layer.send("websocket.receive", {
|
"type": "websocket.receive",
|
||||||
"reply_channel": self.reply_channel,
|
"bytes": payload,
|
||||||
"path": self.unquote(self.path),
|
})
|
||||||
"order": self.packets_received,
|
else:
|
||||||
"bytes": payload,
|
self.application_queue.put_nowait({
|
||||||
})
|
"type": "websocket.receive",
|
||||||
|
"text": payload.decode("utf8"),
|
||||||
|
})
|
||||||
|
|
||||||
|
def onClose(self, wasClean, code, reason):
|
||||||
|
"""
|
||||||
|
Called when Twisted closes the socket.
|
||||||
|
"""
|
||||||
|
self.server.discard_protocol(self)
|
||||||
|
logger.debug("WebSocket closed for %s", self.client_addr)
|
||||||
|
if not self.muted:
|
||||||
|
self.application_queue.put_nowait({
|
||||||
|
"type": "websocket.disconnect",
|
||||||
|
"code": code,
|
||||||
|
})
|
||||||
|
self.server.log_action("websocket", "disconnected", {
|
||||||
|
"path": self.request.path,
|
||||||
|
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||||
|
})
|
||||||
|
|
||||||
|
### Internal event handling
|
||||||
|
|
||||||
|
def handle_reply(self, message):
|
||||||
|
if "type" not in message:
|
||||||
|
raise ValueError("Message has no type defined")
|
||||||
|
if message["type"] == "websocket.accept":
|
||||||
|
self.serverAccept()
|
||||||
|
elif message["type"] == "websocket.close":
|
||||||
|
if self.state == self.STATE_CONNECTING:
|
||||||
|
self.serverReject()
|
||||||
else:
|
else:
|
||||||
self.channel_layer.send("websocket.receive", {
|
self.serverClose(code=message.get("code", None))
|
||||||
"reply_channel": self.reply_channel,
|
elif message["type"] == "websocket.send":
|
||||||
"path": self.unquote(self.path),
|
if self.state == self.STATE_CONNECTING:
|
||||||
"order": self.packets_received,
|
raise ValueError("Socket has not been accepted, so cannot send over it")
|
||||||
"text": payload.decode("utf8"),
|
if message.get("bytes", None) and message.get("text", None):
|
||||||
})
|
raise ValueError(
|
||||||
except self.channel_layer.ChannelFull:
|
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
|
||||||
# You have to consume websocket.receive according to the spec,
|
channel,
|
||||||
# so drop the connection.
|
)
|
||||||
self.muted = True
|
)
|
||||||
logger.warn("WebSocket force closed for %s due to receive backpressure", self.reply_channel)
|
if message.get("bytes", None):
|
||||||
# Send code 1013 "try again later" with close.
|
self.serverSend(message["bytes"], True)
|
||||||
self.sendCloseFrame(code=1013, isReply=False)
|
if message.get("text", None):
|
||||||
|
self.serverSend(message["text"], False)
|
||||||
|
|
||||||
def serverAccept(self):
|
def serverAccept(self):
|
||||||
"""
|
"""
|
||||||
Called when we get a message saying to accept the connection.
|
Called when we get a message saying to accept the connection.
|
||||||
"""
|
"""
|
||||||
self.handshake_deferred.callback(self.protocol_to_accept)
|
self.handshake_deferred.callback(self.protocol_to_accept)
|
||||||
logger.debug("WebSocket %s accepted by application", self.reply_channel)
|
logger.debug("WebSocket %s accepted by application", self.client_addr)
|
||||||
|
|
||||||
def serverReject(self):
|
def serverReject(self):
|
||||||
"""
|
"""
|
||||||
Called when we get a message saying to reject the connection.
|
Called when we get a message saying to reject the connection.
|
||||||
"""
|
"""
|
||||||
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
||||||
self.cleanup()
|
self.server.discard_protocol(self)
|
||||||
logger.debug("WebSocket %s rejected by application", self.reply_channel)
|
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
||||||
self.factory.log_action("websocket", "rejected", {
|
self.server.log_action("websocket", "rejected", {
|
||||||
"path": self.request.path,
|
"path": self.request.path,
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||||
})
|
})
|
||||||
|
@ -191,47 +200,30 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
if self.state == self.STATE_CONNECTING:
|
if self.state == self.STATE_CONNECTING:
|
||||||
self.serverAccept()
|
self.serverAccept()
|
||||||
self.last_data = time.time()
|
self.last_data = time.time()
|
||||||
logger.debug("Sent WebSocket packet to client for %s", self.reply_channel)
|
logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
|
||||||
if binary:
|
if binary:
|
||||||
self.sendMessage(content, binary)
|
self.sendMessage(content, binary)
|
||||||
else:
|
else:
|
||||||
self.sendMessage(content.encode("utf8"), binary)
|
self.sendMessage(content.encode("utf8"), binary)
|
||||||
|
|
||||||
def serverClose(self, code=True):
|
def serverClose(self, code=None):
|
||||||
"""
|
"""
|
||||||
Server-side channel message to close the socket
|
Server-side channel message to close the socket
|
||||||
"""
|
"""
|
||||||
code = 1000 if code is True else code
|
code = 1000 if code is None else code
|
||||||
self.sendClose(code=code)
|
self.sendClose(code=code)
|
||||||
|
|
||||||
def onClose(self, wasClean, code, reason):
|
### Utils
|
||||||
self.cleanup()
|
|
||||||
if hasattr(self, "reply_channel"):
|
|
||||||
logger.debug("WebSocket closed for %s", self.reply_channel)
|
|
||||||
try:
|
|
||||||
if not self.muted:
|
|
||||||
self.channel_layer.send("websocket.disconnect", {
|
|
||||||
"reply_channel": self.reply_channel,
|
|
||||||
"code": code,
|
|
||||||
"path": self.unquote(self.path),
|
|
||||||
"order": self.packets_received + 1,
|
|
||||||
})
|
|
||||||
except self.channel_layer.ChannelFull:
|
|
||||||
pass
|
|
||||||
self.factory.log_action("websocket", "disconnected", {
|
|
||||||
"path": self.request.path,
|
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
logger.debug("WebSocket closed before handshake established")
|
|
||||||
|
|
||||||
def cleanup(self):
|
@classmethod
|
||||||
|
def unquote(cls, value):
|
||||||
"""
|
"""
|
||||||
Call to clean up this socket after it's closed.
|
Python 2 and 3 compat layer for utf-8 unquoting
|
||||||
"""
|
"""
|
||||||
if hasattr(self, "reply_channel"):
|
if six.PY2:
|
||||||
if self.reply_channel in self.factory.reply_protocols:
|
return unquote(value).decode("utf8")
|
||||||
del self.factory.reply_protocols[self.reply_channel]
|
else:
|
||||||
|
return unquote(value.decode("ascii"))
|
||||||
|
|
||||||
def duration(self):
|
def duration(self):
|
||||||
"""
|
"""
|
||||||
|
@ -275,3 +267,15 @@ class WebSocketFactory(WebSocketServerFactory):
|
||||||
def __init__(self, server_class, *args, **kwargs):
|
def __init__(self, server_class, *args, **kwargs):
|
||||||
self.server_class = server_class
|
self.server_class = server_class
|
||||||
WebSocketServerFactory.__init__(self, *args, **kwargs)
|
WebSocketServerFactory.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
|
def buildProtocol(self, addr):
|
||||||
|
"""
|
||||||
|
Builds protocol instances. We use this to inject the factory object into the protocol.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
protocol = super(WebSocketFactory, self).buildProtocol(addr)
|
||||||
|
protocol.factory = self
|
||||||
|
return protocol
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
Loading…
Reference in New Issue
Block a user