Trying out asyncio based interface

This commit is contained in:
Andrew Godwin 2017-09-07 21:24:14 -07:00
parent a656c9f4c6
commit 01f174bf26
5 changed files with 224 additions and 292 deletions

View File

@ -131,12 +131,8 @@ class CommandLineInterface(object):
default=2, default=2,
) )
self.parser.add_argument( self.parser.add_argument(
'channel_layer', 'application',
help='The ASGI channel layer instance to use as path.to.module:instance.path', help='The application to dispatch to as path.to.module:instance.path',
)
self.parser.add_argument(
'consumer',
help='The consumer to dispatch to as path.to.module:instance.path',
) )
self.server = None self.server = None
@ -161,6 +157,7 @@ class CommandLineInterface(object):
0: logging.WARN, 0: logging.WARN,
1: logging.INFO, 1: logging.INFO,
2: logging.DEBUG, 2: logging.DEBUG,
3: logging.DEBUG, # Also turns on asyncio debug
}[args.verbosity], }[args.verbosity],
format="%(asctime)-15s %(levelname)-8s %(message)s", format="%(asctime)-15s %(levelname)-8s %(message)s",
) )
@ -173,10 +170,9 @@ class CommandLineInterface(object):
access_log_stream = open(args.access_log, "a", 1) access_log_stream = open(args.access_log, "a", 1)
elif args.verbosity >= 1: elif args.verbosity >= 1:
access_log_stream = sys.stdout access_log_stream = sys.stdout
# Import application and channel_layer # Import application
sys.path.insert(0, ".") sys.path.insert(0, ".")
channel_layer = import_by_path(args.channel_layer) application = import_by_path(args.application)
consumer = import_by_path(args.consumer)
# Set up port/host bindings # Set up port/host bindings
if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]): if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
# no advanced binding options passed, patch in defaults # no advanced binding options passed, patch in defaults
@ -198,12 +194,11 @@ class CommandLineInterface(object):
) )
# Start the server # Start the server
logger.info( logger.info(
'Starting server at %s, channel layer %s.' % 'Starting server at %s' %
(', '.join(endpoints), args.channel_layer) (', '.join(endpoints), )
) )
self.server = Server( self.server = Server(
channel_layer=channel_layer, application=application,
consumer=consumer,
endpoints=endpoints, endpoints=endpoints,
http_timeout=args.http_timeout, http_timeout=args.http_timeout,
ping_interval=args.ping_interval, ping_interval=args.ping_interval,

View File

@ -10,6 +10,7 @@ import traceback
from zope.interface import implementer from zope.interface import implementer
from six.moves.urllib_parse import unquote, unquote_plus from six.moves.urllib_parse import unquote, unquote_plus
from twisted.internet.defer import ensureDeferred
from twisted.internet.interfaces import IProtocolNegotiationFactory from twisted.internet.interfaces import IProtocolNegotiationFactory
from twisted.protocols.policies import ProtocolWrapper from twisted.protocols.policies import ProtocolWrapper
from twisted.web import http from twisted.web import http
@ -29,7 +30,7 @@ class WebRequest(http.Request):
GET and POST out. GET and POST out.
""" """
consumer_type = "http" application_type = "http"
error_template = """ error_template = """
<html> <html>
@ -55,7 +56,7 @@ class WebRequest(http.Request):
http.Request.__init__(self, *args, **kwargs) http.Request.__init__(self, *args, **kwargs)
# Easy server link # Easy server link
self.server = self.channel.factory.server self.server = self.channel.factory.server
self.consumer_channel = None self.application_queue = None
self._response_started = False self._response_started = False
self.server.add_protocol(self) self.server.add_protocol(self)
except Exception: except Exception:
@ -137,8 +138,8 @@ class WebRequest(http.Request):
# Boring old HTTP. # Boring old HTTP.
else: else:
# Create consumer to handle this connection # Create application to handle this connection
self.consumer_channel = self.server.create_consumer(self) self.application_queue = self.server.create_application(self)
# Sanitize and decode headers, potentially extracting root path # Sanitize and decode headers, potentially extracting root path
self.clean_headers = [] self.clean_headers = []
self.root_path = self.server.root_path self.root_path = self.server.root_path
@ -151,11 +152,10 @@ class WebRequest(http.Request):
self.root_path = self.unquote(value) self.root_path = self.unquote(value)
else: else:
self.clean_headers.append((name.lower(), value)) self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.consumer_channel) logger.debug("HTTP %s request for %s", self.method, self.client_addr)
self.content.seek(0, 0) self.content.seek(0, 0)
# Run consumer # Run application against request
self.server.handle_message( self.application_queue.put_nowait(
self.consumer_channel,
{ {
"type": "http.request", "type": "http.request",
# TODO: Correctly say if it's 1.1 or 1.0 # TODO: Correctly say if it's 1.1 or 1.0
@ -173,13 +173,13 @@ class WebRequest(http.Request):
) )
except Exception: except Exception:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self.basic_error(500, b"Internal Server Error", "HTTP processing error") self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error")
def connectionLost(self, reason): def connectionLost(self, reason):
""" """
Cleans up reply channel on close. Cleans up reply channel on close.
""" """
if self.consumer_channel: if self.application_queue:
self.send_disconnect() self.send_disconnect()
logger.debug("HTTP disconnect for %s", self.client_addr) logger.debug("HTTP disconnect for %s", self.client_addr)
http.Request.connectionLost(self, reason) http.Request.connectionLost(self, reason)
@ -189,7 +189,7 @@ class WebRequest(http.Request):
""" """
Cleans up reply channel on close. Cleans up reply channel on close.
""" """
if self.consumer_channel: if self.application_queue:
self.send_disconnect() self.send_disconnect()
logger.debug("HTTP close for %s", self.client_addr) logger.debug("HTTP close for %s", self.client_addr)
http.Request.finish(self) http.Request.finish(self)
@ -201,6 +201,8 @@ class WebRequest(http.Request):
""" """
Handles a reply from the client Handles a reply from the client
""" """
if "type" not in message:
raise ValueError("Message has no type defined")
if message['type'] == "http.response": if message['type'] == "http.response":
if self._response_started: if self._response_started:
raise ValueError("HTTP response has already been started") raise ValueError("HTTP response has already been started")
@ -216,9 +218,9 @@ class WebRequest(http.Request):
header = header.encode("latin1") header = header.encode("latin1")
self.responseHeaders.addRawHeader(header, value) self.responseHeaders.addRawHeader(header, value)
logger.debug("HTTP %s response started for %s", message['status'], self.client_addr) logger.debug("HTTP %s response started for %s", message['status'], self.client_addr)
elif message['type'] == "http.response.content": elif message['type'] == "http.response.hunk":
if not self._response_started: if not self._response_started:
raise ValueError("HTTP response has not yet been started") raise ValueError("HTTP response has not yet been started but got %s" % message['type'])
else: else:
raise ValueError("Cannot handle message type %s!" % message['type']) raise ValueError("Cannot handle message type %s!" % message['type'])
@ -243,6 +245,12 @@ class WebRequest(http.Request):
else: else:
logger.debug("HTTP response chunk for %s", self.client_addr) logger.debug("HTTP response chunk for %s", self.client_addr)
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
self.basic_error(500, b"Internal Server Error", "Exception inside application.")
def check_timeouts(self): def check_timeouts(self):
""" """
Called periodically to see if we should timeout something Called periodically to see if we should timeout something
@ -276,8 +284,7 @@ class WebRequest(http.Request):
""" """
# If we don't yet have a path, then don't send as we never opened. # If we don't yet have a path, then don't send as we never opened.
if self.path: if self.path:
self.server.handle_message( self.application_queue.put_nowait(
self.consumer_channel,
{ {
"type": "http.disconnect", "type": "http.disconnect",
}, },
@ -295,7 +302,8 @@ class WebRequest(http.Request):
""" """
Responds with a server-level error page (very basic) Responds with a server-level error page (very basic)
""" """
self.serverResponse({ self.handle_reply({
"type": "http.response",
"status": status, "status": status,
"status_text": status_text, "status_text": status_text,
"headers": [ "headers": [
@ -340,83 +348,6 @@ class HTTPFactory(http.HTTPFactory):
logger.error("Cannot build protocol: %s" % traceback.format_exc()) logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise raise
def make_send_channel(self):
"""
Makes a new send channel for a protocol with our process prefix.
"""
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
return self.send_channel + protocol_id
def dispatch_reply(self, channel, message):
# If we don't know about the channel, ignore it (likely a channel we
# used to have that's now in a group).
# TODO: Find a better way of alerting people when this happens so
# they can do more cleanup, that's not an error.
if channel not in self.reply_protocols:
logger.debug("Message on unknown channel %r - discarding" % channel)
return
if isinstance(self.reply_protocols[channel], WebRequest):
self.reply_protocols[channel].serverResponse(message)
elif isinstance(self.reply_protocols[channel], WebSocketProtocol):
# Switch depending on current socket state
protocol = self.reply_protocols[channel]
# See if the message is valid
unknown_keys = set(message.keys()) - {"bytes", "text", "close", "accept"}
if unknown_keys:
raise ValueError(
"Got invalid WebSocket reply message on %s - "
"contains unknown keys %s (looking for either {'accept', 'text', 'bytes', 'close'})" % (
channel,
unknown_keys,
)
)
# Accepts allow bytes/text afterwards
if message.get("accept", None) and protocol.state == protocol.STATE_CONNECTING:
protocol.serverAccept()
# Rejections must be the only thing
if message.get("accept", None) == False and protocol.state == protocol.STATE_CONNECTING:
protocol.serverReject()
return
# You're only allowed one of bytes or text
if message.get("bytes", None) and message.get("text", None):
raise ValueError(
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
channel,
)
)
if message.get("bytes", None):
protocol.serverSend(message["bytes"], True)
if message.get("text", None):
protocol.serverSend(message["text"], False)
closing_code = message.get("close", False)
if closing_code:
if protocol.state == protocol.STATE_CONNECTING:
protocol.serverReject()
else:
protocol.serverClose(code=closing_code)
else:
raise ValueError("Unknown protocol class")
def check_timeouts(self):
"""
Runs through all HTTP protocol instances and times them out if they've
taken too long (and so their message is probably expired)
"""
for protocol in list(self.reply_protocols.values()):
# Web timeout checking
if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout:
protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.")
# WebSocket timeout checking and keepalive ping sending
elif isinstance(protocol, WebSocketProtocol):
# Timeout check
if protocol.duration() > self.websocket_timeout and self.websocket_timeout >= 0:
protocol.serverClose()
# Ping check
else:
protocol.check_ping()
# IProtocolNegotiationFactory # IProtocolNegotiationFactory
def acceptableProtocols(self): def acceptableProtocols(self):
""" """

View File

@ -1,16 +1,21 @@
from __future__ import unicode_literals from __future__ import unicode_literals
# This has to be done first as Twisted is import-order-sensitive with reactors
from twisted.internet import asyncioreactor
asyncioreactor.install()
import asyncio
import collections import collections
import logging import logging
import random import random
import string import string
import traceback
import warnings import warnings
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.internet.endpoints import serverFromString from twisted.internet.endpoints import serverFromString
from twisted.logger import globalLogBeginner, STDLibLogObserver from twisted.logger import globalLogBeginner, STDLibLogObserver
from twisted.web import http from twisted.web import http
from twisted.python.threadpool import ThreadPool
from .http_protocol import HTTPFactory from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory from .ws_protocol import WebSocketFactory
@ -22,13 +27,12 @@ class Server(object):
def __init__( def __init__(
self, self,
channel_layer, application,
consumer,
endpoints=None, endpoints=None,
signal_handlers=True, signal_handlers=True,
action_logger=None, action_logger=None,
http_timeout=120, http_timeout=120,
websocket_timeout=None, websocket_timeout=86400,
websocket_connect_timeout=20, websocket_connect_timeout=20,
ping_interval=20, ping_interval=20,
ping_timeout=30, ping_timeout=30,
@ -39,10 +43,9 @@ class Server(object):
verbosity=1, verbosity=1,
websocket_handshake_timeout=5 websocket_handshake_timeout=5
): ):
self.channel_layer = channel_layer self.application = application
self.consumer = consumer
self.endpoints = endpoints or [] self.endpoints = endpoints or []
if len(self.endpoints) == 0: if not self.endpoints:
raise UserWarning("No endpoints. This server will not listen on anything.") raise UserWarning("No endpoints. This server will not listen on anything.")
self.listeners = [] self.listeners = []
self.signal_handlers = signal_handlers self.signal_handlers = signal_handlers
@ -52,30 +55,20 @@ class Server(object):
self.ping_timeout = ping_timeout self.ping_timeout = ping_timeout
self.proxy_forwarded_address_header = proxy_forwarded_address_header self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header self.proxy_forwarded_port_header = proxy_forwarded_port_header
# If they did not provide a websocket timeout, default it to the self.websocket_timeout = websocket_timeout
# channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
self.websocket_connect_timeout = websocket_connect_timeout self.websocket_connect_timeout = websocket_connect_timeout
self.websocket_handshake_timeout = websocket_handshake_timeout self.websocket_handshake_timeout = websocket_handshake_timeout
self.ws_protocols = ws_protocols self.websocket_protocols = ws_protocols
self.root_path = root_path self.root_path = root_path
self.verbosity = verbosity self.verbosity = verbosity
def run(self): def run(self):
# Make the thread pool to run consumers in
# TODO: Configurable numbers of threads
self.pool = ThreadPool(name="consumers")
# Make the mapping of consumer instances to consumer channels
self.consumer_instances = {}
# A set of current Twisted protocol instances to manage # A set of current Twisted protocol instances to manage
self.protocols = set() self.protocols = set()
# Create process-local channel prefixes self.application_instances = {}
# TODO: Can we guarantee non-collision better?
process_id = "".join(random.choice(string.ascii_letters) for i in range(10))
self.consumer_channel_prefix = "daphne.%s!" % process_id
# Make the factory # Make the factory
self.http_factory = HTTPFactory(self) self.http_factory = HTTPFactory(self)
self.ws_factory = WebSocketFactory(self, protocols=self.ws_protocols, server='Daphne') self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server='Daphne')
self.ws_factory.setProtocolOptions( self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout, autoPingTimeout=self.ping_timeout,
allowNullOrigin=True, allowNullOrigin=True,
@ -93,18 +86,24 @@ class Server(object):
else: else:
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)") logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
# Kick off the various background loops # Kick off the timeout loop
reactor.callLater(0, self.backend_reader) reactor.callLater(1, self.application_checker)
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)
for socket_description in self.endpoints: for socket_description in self.endpoints:
logger.info("Listening on endpoint %s" % socket_description) logger.info("Listening on endpoint %s" % socket_description)
# Twisted requires str on python2 (not unicode) and str on python3 (not bytes)
ep = serverFromString(reactor, str(socket_description)) ep = serverFromString(reactor, str(socket_description))
self.listeners.append(ep.listen(self.http_factory)) self.listeners.append(ep.listen(self.http_factory))
self.pool.start() # Set the asyncio reactor's event loop as global
reactor.addSystemEventTrigger("before", "shutdown", self.pool.stop) # TODO: Should we instead pass the global one into the reactor?
asyncio.set_event_loop(reactor._asyncioEventloop)
# Verbosity 3 turns on asyncio debug to find those blocking yields
if self.verbosity >= 3:
asyncio.get_event_loop().set_debug(True)
reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
reactor.run(installSignalHandlers=self.signal_handlers) reactor.run(installSignalHandlers=self.signal_handlers)
### Protocol handling ### Protocol handling
@ -115,83 +114,86 @@ class Server(object):
self.protocols.add(protocol) self.protocols.add(protocol)
def discard_protocol(self, protocol): def discard_protocol(self, protocol):
# Ensure it's not in the protocol-tracking set
self.protocols.discard(protocol) self.protocols.discard(protocol)
# Make sure any application future that's running is cancelled
if protocol in self.application_instances:
self.application_instances[protocol].cancel()
del self.application_instances[protocol]
### Internal event/message handling ### Internal event/message handling
def create_consumer(self, protocol): def create_application(self, protocol):
""" """
Creates a new consumer instance that fronts a Protocol instance Creates a new application instance that fronts a Protocol instance
for one of our supported protocols. Pass it the protocol, for one of our supported protocols. Pass it the protocol,
and it will work out the type, supply appropriate callables, and and it will work out the type, supply appropriate callables, and
put it into the server's consumer pool. return you the application's input queue
It returns the consumer channel name, which is how you should refer
to the consumer instance.
""" """
# Make sure the protocol defines a consumer type # Make sure the protocol defines a application type
assert protocol.consumer_type is not None assert protocol.application_type is not None
# Make it a consumer channel name # Make sure the protocol has not had another application made for it
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10)) assert protocol not in self.application_instances
consumer_channel = self.consumer_channel_prefix + protocol_id # Make an instance of the application
# Make an instance of the consumer input_queue = asyncio.Queue()
consumer_instance = self.consumer( application_instance = asyncio.ensure_future(self.application(
type=protocol.consumer_type, type=protocol.application_type,
next=input_queue.get,
reply=lambda message: self.handle_reply(protocol, message), reply=lambda message: self.handle_reply(protocol, message),
channel_layer=self.channel_layer, ), loop=asyncio.get_event_loop())
consumer_channel=consumer_channel, self.application_instances[protocol] = application_instance
) return input_queue
# Assign it by channel and return it
self.consumer_instances[consumer_channel] = consumer_instance
return consumer_channel
def handle_message(self, consumer_channel, message): async def handle_reply(self, protocol, message):
""" """
Schedules the application instance to handle the given message. Coroutine that jumps the reply message from asyncio to Twisted
""" """
self.pool.callInThread(self.consumer_instances[consumer_channel], message) reactor.callLater(0, protocol.handle_reply, message)
def handle_reply(self, protocol, message):
"""
Schedules the reply to be handled by the protocol in the main thread
"""
reactor.callFromThread(reactor.callLater, 0, protocol.handle_reply, message)
### External event/message handling
def backend_reader(self):
"""
Runs as an-often-as-possible task with the reactor, unless there was
no result previously in which case we add a small delay.
"""
channels = [self.consumer_channel_prefix]
delay = 0
# Quit if reactor is stopping
if not reactor.running:
logger.debug("Backend reader quitting due to reactor stop")
return
# Try to receive a message
try:
channel, message = self.channel_layer.receive(channels, block=False)
except Exception as e:
# Log the error and wait a bit to retry
logger.error('Error trying to receive messages: %s' % e)
delay = 5.00
else:
if channel:
# Deal with the message
try:
self.handle_message(channel, message)
except Exception as e:
logger.error("Error handling external message: %s" % e)
else:
# If there's no messages, idle a little bit.
delay = 0.05
# We can't loop inside here as this is synchronous code.
reactor.callLater(delay, self.backend_reader)
### Utility ### Utility
def application_checker(self):
"""
Goes through the set of current application Futures and cleans up
any that are done/prints exceptions for any that errored.
"""
for protocol, application_instance in list(self.application_instances.items()):
if application_instance.done():
exception = application_instance.exception()
if exception:
logging.error(
"Exception inside application: {}\n{}{}".format(
exception,
"".join(traceback.format_tb(
exception.__traceback__,
)),
" {}".format(exception),
)
)
protocol.handle_exception(exception)
try:
del self.application_instances[protocol]
except KeyError:
# The protocol might have already got here before us. That's fine.
pass
reactor.callLater(1, self.application_checker)
def kill_all_applications(self):
"""
Kills all application coroutines before reactor exit.
"""
# Send cancel to all coroutines
wait_for = []
for application_instance in self.application_instances.values():
if not application_instance.done():
application_instance.cancel()
wait_for.append(application_instance)
logging.info("Killed %i pending application instances" % len(wait_for))
# Make Twisted wait until they're all dead
wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for))
wait_deferred.addErrback(lambda x: None)
return wait_deferred
def timeout_checker(self): def timeout_checker(self):
""" """
Called periodically to enforce timeout rules on all connections. Called periodically to enforce timeout rules on all connections.

View File

@ -20,20 +20,21 @@ class WebSocketProtocol(WebSocketServerProtocol):
the websocket channels. the websocket channels.
""" """
application_type = "websocket"
# If we should send no more messages (e.g. we error-closed the socket) # If we should send no more messages (e.g. we error-closed the socket)
muted = False muted = False
def __init__(self, *args, **kwargs):
super(WebSocketProtocol, self).__init__(*args, **kwargs)
self.server = self.factory.server_class
def onConnect(self, request): def onConnect(self, request):
self.server = self.factory.server_class
self.server.add_protocol(self)
self.request = request self.request = request
self.packets_received = 0
self.protocol_to_accept = None self.protocol_to_accept = None
self.socket_opened = time.time() self.socket_opened = time.time()
self.last_data = time.time() self.last_data = time.time()
try: try:
# Make new application instance
self.application_queue = self.server.create_application(self)
# Sanitize and decode headers # Sanitize and decode headers
self.clean_headers = [] self.clean_headers = []
for name, value in request.headers.items(): for name, value in request.headers.items():
@ -42,10 +43,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
if b"_" in name: if b"_" in name:
continue continue
self.clean_headers.append((name.lower(), value.encode("latin1"))) self.clean_headers.append((name.lower(), value.encode("latin1")))
# Make sending channel
self.reply_channel = self.main_factory.make_send_channel()
# Tell main factory about it
self.main_factory.reply_protocols[self.reply_channel] = self
# Get client address if possible # Get client address if possible
peer = self.transport.getPeer() peer = self.transport.getPeer()
host = self.transport.getHost() host = self.transport.getHost()
@ -56,23 +53,23 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.client_addr = None self.client_addr = None
self.server_addr = None self.server_addr = None
if self.main_factory.proxy_forwarded_address_header: if self.server.proxy_forwarded_address_header:
self.client_addr = parse_x_forwarded_for( self.client_addr = parse_x_forwarded_for(
self.http_headers, self.http_headers,
self.main_factory.proxy_forwarded_address_header, self.server.proxy_forwarded_address_header,
self.main_factory.proxy_forwarded_port_header, self.server.proxy_forwarded_port_header,
self.client_addr self.client_addr
) )
# Make initial request info dict from request (we only have it here) # Make initial request info dict from request (we only have it here)
self.path = request.path.encode("ascii") self.path = request.path.encode("ascii")
self.request_info = { self.connect_message = {
"type": "websocket.connect",
"path": self.unquote(self.path), "path": self.unquote(self.path),
"headers": self.clean_headers, "headers": self.clean_headers,
"query_string": self._raw_query_string, # Passed by HTTP protocol "query_string": self._raw_query_string, # Passed by HTTP protocol
"client": self.client_addr, "client": self.client_addr,
"server": self.server_addr, "server": self.server_addr,
"reply_channel": self.reply_channel,
"order": 0, "order": 0,
} }
except: except:
@ -86,50 +83,33 @@ class WebSocketProtocol(WebSocketServerProtocol):
if header == b'sec-websocket-protocol': if header == b'sec-websocket-protocol':
protocols = [x.strip() for x in self.unquote(value).split(",")] protocols = [x.strip() for x in self.unquote(value).split(",")]
for protocol in protocols: for protocol in protocols:
if protocol in self.factory.protocols: if protocol in self.server.websocket_protocols:
ws_protocol = protocol ws_protocol = protocol
break break
# Work out what subprotocol we will accept, if any # Work out what subprotocol we will accept, if any
if ws_protocol and ws_protocol in self.factory.protocols: if ws_protocol and ws_protocol in self.server.websocket_protocols:
self.protocol_to_accept = ws_protocol self.protocol_to_accept = ws_protocol
else: else:
self.protocol_to_accept = None self.protocol_to_accept = None
# Send over the connect message # Send over the connect message
try: self.application_queue.put_nowait(self.connect_message)
self.channel_layer.send("websocket.connect", self.request_info) self.server.log_action("websocket", "connecting", {
except self.channel_layer.ChannelFull: "path": self.request.path,
# You have to consume websocket.connect according to the spec, "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
# so drop the connection. })
self.muted = True
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
# Send code 503 "Service Unavailable" with close.
raise ConnectionDeny(code=503, reason="Connection queue at capacity")
else:
self.factory.log_action("websocket", "connecting", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
# Make a deferred and return it - we'll either call it or err it later on # Make a deferred and return it - we'll either call it or err it later on
self.handshake_deferred = defer.Deferred() self.handshake_deferred = defer.Deferred()
return self.handshake_deferred return self.handshake_deferred
@classmethod ### Twisted event handling
def unquote(cls, value):
"""
Python 2 and 3 compat layer for utf-8 unquoting
"""
if six.PY2:
return unquote(value).decode("utf8")
else:
return unquote(value.decode("ascii"))
def onOpen(self): def onOpen(self):
# Send news that this channel is open # Send news that this channel is open
logger.debug("WebSocket %s open and established", self.reply_channel) logger.debug("WebSocket %s open and established", self.client_addr)
self.factory.log_action("websocket", "connected", { self.server.log_action("websocket", "connected", {
"path": self.request.path, "path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
}) })
@ -137,49 +117,78 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onMessage(self, payload, isBinary): def onMessage(self, payload, isBinary):
# If we're muted, do nothing. # If we're muted, do nothing.
if self.muted: if self.muted:
logger.debug("Muting incoming frame on %s", self.reply_channel) logger.debug("Muting incoming frame on %s", self.client_addr)
return return
logger.debug("WebSocket incoming frame on %s", self.reply_channel) logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.packets_received += 1
self.last_data = time.time() self.last_data = time.time()
try: if isBinary:
if isBinary: self.application_queue.put_nowait({
self.channel_layer.send("websocket.receive", { "type": "websocket.receive",
"reply_channel": self.reply_channel, "bytes": payload,
"path": self.unquote(self.path), })
"order": self.packets_received, else:
"bytes": payload, self.application_queue.put_nowait({
}) "type": "websocket.receive",
"text": payload.decode("utf8"),
})
def onClose(self, wasClean, code, reason):
"""
Called when Twisted closes the socket.
"""
self.server.discard_protocol(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted:
self.application_queue.put_nowait({
"type": "websocket.disconnect",
"code": code,
})
self.server.log_action("websocket", "disconnected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
### Internal event handling
def handle_reply(self, message):
if "type" not in message:
raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept":
self.serverAccept()
elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING:
self.serverReject()
else: else:
self.channel_layer.send("websocket.receive", { self.serverClose(code=message.get("code", None))
"reply_channel": self.reply_channel, elif message["type"] == "websocket.send":
"path": self.unquote(self.path), if self.state == self.STATE_CONNECTING:
"order": self.packets_received, raise ValueError("Socket has not been accepted, so cannot send over it")
"text": payload.decode("utf8"), if message.get("bytes", None) and message.get("text", None):
}) raise ValueError(
except self.channel_layer.ChannelFull: "Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
# You have to consume websocket.receive according to the spec, channel,
# so drop the connection. )
self.muted = True )
logger.warn("WebSocket force closed for %s due to receive backpressure", self.reply_channel) if message.get("bytes", None):
# Send code 1013 "try again later" with close. self.serverSend(message["bytes"], True)
self.sendCloseFrame(code=1013, isReply=False) if message.get("text", None):
self.serverSend(message["text"], False)
def serverAccept(self): def serverAccept(self):
""" """
Called when we get a message saying to accept the connection. Called when we get a message saying to accept the connection.
""" """
self.handshake_deferred.callback(self.protocol_to_accept) self.handshake_deferred.callback(self.protocol_to_accept)
logger.debug("WebSocket %s accepted by application", self.reply_channel) logger.debug("WebSocket %s accepted by application", self.client_addr)
def serverReject(self): def serverReject(self):
""" """
Called when we get a message saying to reject the connection. Called when we get a message saying to reject the connection.
""" """
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
self.cleanup() self.server.discard_protocol(self)
logger.debug("WebSocket %s rejected by application", self.reply_channel) logger.debug("WebSocket %s rejected by application", self.client_addr)
self.factory.log_action("websocket", "rejected", { self.server.log_action("websocket", "rejected", {
"path": self.request.path, "path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
}) })
@ -191,47 +200,30 @@ class WebSocketProtocol(WebSocketServerProtocol):
if self.state == self.STATE_CONNECTING: if self.state == self.STATE_CONNECTING:
self.serverAccept() self.serverAccept()
self.last_data = time.time() self.last_data = time.time()
logger.debug("Sent WebSocket packet to client for %s", self.reply_channel) logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
if binary: if binary:
self.sendMessage(content, binary) self.sendMessage(content, binary)
else: else:
self.sendMessage(content.encode("utf8"), binary) self.sendMessage(content.encode("utf8"), binary)
def serverClose(self, code=True): def serverClose(self, code=None):
""" """
Server-side channel message to close the socket Server-side channel message to close the socket
""" """
code = 1000 if code is True else code code = 1000 if code is None else code
self.sendClose(code=code) self.sendClose(code=code)
def onClose(self, wasClean, code, reason): ### Utils
self.cleanup()
if hasattr(self, "reply_channel"):
logger.debug("WebSocket closed for %s", self.reply_channel)
try:
if not self.muted:
self.channel_layer.send("websocket.disconnect", {
"reply_channel": self.reply_channel,
"code": code,
"path": self.unquote(self.path),
"order": self.packets_received + 1,
})
except self.channel_layer.ChannelFull:
pass
self.factory.log_action("websocket", "disconnected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
else:
logger.debug("WebSocket closed before handshake established")
def cleanup(self): @classmethod
def unquote(cls, value):
""" """
Call to clean up this socket after it's closed. Python 2 and 3 compat layer for utf-8 unquoting
""" """
if hasattr(self, "reply_channel"): if six.PY2:
if self.reply_channel in self.factory.reply_protocols: return unquote(value).decode("utf8")
del self.factory.reply_protocols[self.reply_channel] else:
return unquote(value.decode("ascii"))
def duration(self): def duration(self):
""" """
@ -275,3 +267,15 @@ class WebSocketFactory(WebSocketServerFactory):
def __init__(self, server_class, *args, **kwargs): def __init__(self, server_class, *args, **kwargs):
self.server_class = server_class self.server_class = server_class
WebSocketServerFactory.__init__(self, *args, **kwargs) WebSocketServerFactory.__init__(self, *args, **kwargs)
def buildProtocol(self, addr):
"""
Builds protocol instances. We use this to inject the factory object into the protocol.
"""
try:
protocol = super(WebSocketFactory, self).buildProtocol(addr)
protocol.factory = self
return protocol
except Exception as e:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise

View File

@ -24,7 +24,7 @@ setup(
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=[
'asgiref~=1.1', 'asgiref~=1.1',
'twisted>=17.1', 'twisted>=17.5',
'autobahn>=0.18', 'autobahn>=0.18',
], ],
extras_require={ extras_require={