mirror of
https://github.com/django/daphne.git
synced 2024-11-21 23:46:33 +03:00
Trying out asyncio based interface
This commit is contained in:
parent
a656c9f4c6
commit
01f174bf26
|
@ -131,12 +131,8 @@ class CommandLineInterface(object):
|
|||
default=2,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
'channel_layer',
|
||||
help='The ASGI channel layer instance to use as path.to.module:instance.path',
|
||||
)
|
||||
self.parser.add_argument(
|
||||
'consumer',
|
||||
help='The consumer to dispatch to as path.to.module:instance.path',
|
||||
'application',
|
||||
help='The application to dispatch to as path.to.module:instance.path',
|
||||
)
|
||||
|
||||
self.server = None
|
||||
|
@ -161,6 +157,7 @@ class CommandLineInterface(object):
|
|||
0: logging.WARN,
|
||||
1: logging.INFO,
|
||||
2: logging.DEBUG,
|
||||
3: logging.DEBUG, # Also turns on asyncio debug
|
||||
}[args.verbosity],
|
||||
format="%(asctime)-15s %(levelname)-8s %(message)s",
|
||||
)
|
||||
|
@ -173,10 +170,9 @@ class CommandLineInterface(object):
|
|||
access_log_stream = open(args.access_log, "a", 1)
|
||||
elif args.verbosity >= 1:
|
||||
access_log_stream = sys.stdout
|
||||
# Import application and channel_layer
|
||||
# Import application
|
||||
sys.path.insert(0, ".")
|
||||
channel_layer = import_by_path(args.channel_layer)
|
||||
consumer = import_by_path(args.consumer)
|
||||
application = import_by_path(args.application)
|
||||
# Set up port/host bindings
|
||||
if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
|
||||
# no advanced binding options passed, patch in defaults
|
||||
|
@ -198,12 +194,11 @@ class CommandLineInterface(object):
|
|||
)
|
||||
# Start the server
|
||||
logger.info(
|
||||
'Starting server at %s, channel layer %s.' %
|
||||
(', '.join(endpoints), args.channel_layer)
|
||||
'Starting server at %s' %
|
||||
(', '.join(endpoints), )
|
||||
)
|
||||
self.server = Server(
|
||||
channel_layer=channel_layer,
|
||||
consumer=consumer,
|
||||
application=application,
|
||||
endpoints=endpoints,
|
||||
http_timeout=args.http_timeout,
|
||||
ping_interval=args.ping_interval,
|
||||
|
|
|
@ -10,6 +10,7 @@ import traceback
|
|||
from zope.interface import implementer
|
||||
|
||||
from six.moves.urllib_parse import unquote, unquote_plus
|
||||
from twisted.internet.defer import ensureDeferred
|
||||
from twisted.internet.interfaces import IProtocolNegotiationFactory
|
||||
from twisted.protocols.policies import ProtocolWrapper
|
||||
from twisted.web import http
|
||||
|
@ -29,7 +30,7 @@ class WebRequest(http.Request):
|
|||
GET and POST out.
|
||||
"""
|
||||
|
||||
consumer_type = "http"
|
||||
application_type = "http"
|
||||
|
||||
error_template = """
|
||||
<html>
|
||||
|
@ -55,7 +56,7 @@ class WebRequest(http.Request):
|
|||
http.Request.__init__(self, *args, **kwargs)
|
||||
# Easy server link
|
||||
self.server = self.channel.factory.server
|
||||
self.consumer_channel = None
|
||||
self.application_queue = None
|
||||
self._response_started = False
|
||||
self.server.add_protocol(self)
|
||||
except Exception:
|
||||
|
@ -137,8 +138,8 @@ class WebRequest(http.Request):
|
|||
|
||||
# Boring old HTTP.
|
||||
else:
|
||||
# Create consumer to handle this connection
|
||||
self.consumer_channel = self.server.create_consumer(self)
|
||||
# Create application to handle this connection
|
||||
self.application_queue = self.server.create_application(self)
|
||||
# Sanitize and decode headers, potentially extracting root path
|
||||
self.clean_headers = []
|
||||
self.root_path = self.server.root_path
|
||||
|
@ -151,11 +152,10 @@ class WebRequest(http.Request):
|
|||
self.root_path = self.unquote(value)
|
||||
else:
|
||||
self.clean_headers.append((name.lower(), value))
|
||||
logger.debug("HTTP %s request for %s", self.method, self.consumer_channel)
|
||||
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
|
||||
self.content.seek(0, 0)
|
||||
# Run consumer
|
||||
self.server.handle_message(
|
||||
self.consumer_channel,
|
||||
# Run application against request
|
||||
self.application_queue.put_nowait(
|
||||
{
|
||||
"type": "http.request",
|
||||
# TODO: Correctly say if it's 1.1 or 1.0
|
||||
|
@ -173,13 +173,13 @@ class WebRequest(http.Request):
|
|||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
self.basic_error(500, b"Internal Server Error", "HTTP processing error")
|
||||
self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error")
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Cleans up reply channel on close.
|
||||
"""
|
||||
if self.consumer_channel:
|
||||
if self.application_queue:
|
||||
self.send_disconnect()
|
||||
logger.debug("HTTP disconnect for %s", self.client_addr)
|
||||
http.Request.connectionLost(self, reason)
|
||||
|
@ -189,7 +189,7 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
Cleans up reply channel on close.
|
||||
"""
|
||||
if self.consumer_channel:
|
||||
if self.application_queue:
|
||||
self.send_disconnect()
|
||||
logger.debug("HTTP close for %s", self.client_addr)
|
||||
http.Request.finish(self)
|
||||
|
@ -201,6 +201,8 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
Handles a reply from the client
|
||||
"""
|
||||
if "type" not in message:
|
||||
raise ValueError("Message has no type defined")
|
||||
if message['type'] == "http.response":
|
||||
if self._response_started:
|
||||
raise ValueError("HTTP response has already been started")
|
||||
|
@ -216,9 +218,9 @@ class WebRequest(http.Request):
|
|||
header = header.encode("latin1")
|
||||
self.responseHeaders.addRawHeader(header, value)
|
||||
logger.debug("HTTP %s response started for %s", message['status'], self.client_addr)
|
||||
elif message['type'] == "http.response.content":
|
||||
elif message['type'] == "http.response.hunk":
|
||||
if not self._response_started:
|
||||
raise ValueError("HTTP response has not yet been started")
|
||||
raise ValueError("HTTP response has not yet been started but got %s" % message['type'])
|
||||
else:
|
||||
raise ValueError("Cannot handle message type %s!" % message['type'])
|
||||
|
||||
|
@ -243,6 +245,12 @@ class WebRequest(http.Request):
|
|||
else:
|
||||
logger.debug("HTTP response chunk for %s", self.client_addr)
|
||||
|
||||
def handle_exception(self, exception):
|
||||
"""
|
||||
Called by the server when our application tracebacks
|
||||
"""
|
||||
self.basic_error(500, b"Internal Server Error", "Exception inside application.")
|
||||
|
||||
def check_timeouts(self):
|
||||
"""
|
||||
Called periodically to see if we should timeout something
|
||||
|
@ -276,8 +284,7 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
# If we don't yet have a path, then don't send as we never opened.
|
||||
if self.path:
|
||||
self.server.handle_message(
|
||||
self.consumer_channel,
|
||||
self.application_queue.put_nowait(
|
||||
{
|
||||
"type": "http.disconnect",
|
||||
},
|
||||
|
@ -295,7 +302,8 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
Responds with a server-level error page (very basic)
|
||||
"""
|
||||
self.serverResponse({
|
||||
self.handle_reply({
|
||||
"type": "http.response",
|
||||
"status": status,
|
||||
"status_text": status_text,
|
||||
"headers": [
|
||||
|
@ -340,83 +348,6 @@ class HTTPFactory(http.HTTPFactory):
|
|||
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||
raise
|
||||
|
||||
def make_send_channel(self):
|
||||
"""
|
||||
Makes a new send channel for a protocol with our process prefix.
|
||||
"""
|
||||
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
|
||||
return self.send_channel + protocol_id
|
||||
|
||||
def dispatch_reply(self, channel, message):
|
||||
# If we don't know about the channel, ignore it (likely a channel we
|
||||
# used to have that's now in a group).
|
||||
# TODO: Find a better way of alerting people when this happens so
|
||||
# they can do more cleanup, that's not an error.
|
||||
if channel not in self.reply_protocols:
|
||||
logger.debug("Message on unknown channel %r - discarding" % channel)
|
||||
return
|
||||
|
||||
if isinstance(self.reply_protocols[channel], WebRequest):
|
||||
self.reply_protocols[channel].serverResponse(message)
|
||||
elif isinstance(self.reply_protocols[channel], WebSocketProtocol):
|
||||
# Switch depending on current socket state
|
||||
protocol = self.reply_protocols[channel]
|
||||
# See if the message is valid
|
||||
unknown_keys = set(message.keys()) - {"bytes", "text", "close", "accept"}
|
||||
if unknown_keys:
|
||||
raise ValueError(
|
||||
"Got invalid WebSocket reply message on %s - "
|
||||
"contains unknown keys %s (looking for either {'accept', 'text', 'bytes', 'close'})" % (
|
||||
channel,
|
||||
unknown_keys,
|
||||
)
|
||||
)
|
||||
# Accepts allow bytes/text afterwards
|
||||
if message.get("accept", None) and protocol.state == protocol.STATE_CONNECTING:
|
||||
protocol.serverAccept()
|
||||
# Rejections must be the only thing
|
||||
if message.get("accept", None) == False and protocol.state == protocol.STATE_CONNECTING:
|
||||
protocol.serverReject()
|
||||
return
|
||||
# You're only allowed one of bytes or text
|
||||
if message.get("bytes", None) and message.get("text", None):
|
||||
raise ValueError(
|
||||
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
|
||||
channel,
|
||||
)
|
||||
)
|
||||
if message.get("bytes", None):
|
||||
protocol.serverSend(message["bytes"], True)
|
||||
if message.get("text", None):
|
||||
protocol.serverSend(message["text"], False)
|
||||
|
||||
closing_code = message.get("close", False)
|
||||
if closing_code:
|
||||
if protocol.state == protocol.STATE_CONNECTING:
|
||||
protocol.serverReject()
|
||||
else:
|
||||
protocol.serverClose(code=closing_code)
|
||||
else:
|
||||
raise ValueError("Unknown protocol class")
|
||||
|
||||
def check_timeouts(self):
|
||||
"""
|
||||
Runs through all HTTP protocol instances and times them out if they've
|
||||
taken too long (and so their message is probably expired)
|
||||
"""
|
||||
for protocol in list(self.reply_protocols.values()):
|
||||
# Web timeout checking
|
||||
if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout:
|
||||
protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.")
|
||||
# WebSocket timeout checking and keepalive ping sending
|
||||
elif isinstance(protocol, WebSocketProtocol):
|
||||
# Timeout check
|
||||
if protocol.duration() > self.websocket_timeout and self.websocket_timeout >= 0:
|
||||
protocol.serverClose()
|
||||
# Ping check
|
||||
else:
|
||||
protocol.check_ping()
|
||||
|
||||
# IProtocolNegotiationFactory
|
||||
def acceptableProtocols(self):
|
||||
"""
|
||||
|
|
178
daphne/server.py
178
daphne/server.py
|
@ -1,16 +1,21 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
# This has to be done first as Twisted is import-order-sensitive with reactors
|
||||
from twisted.internet import asyncioreactor
|
||||
asyncioreactor.install()
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet.endpoints import serverFromString
|
||||
from twisted.logger import globalLogBeginner, STDLibLogObserver
|
||||
from twisted.web import http
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
|
||||
from .http_protocol import HTTPFactory
|
||||
from .ws_protocol import WebSocketFactory
|
||||
|
@ -22,13 +27,12 @@ class Server(object):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
channel_layer,
|
||||
consumer,
|
||||
application,
|
||||
endpoints=None,
|
||||
signal_handlers=True,
|
||||
action_logger=None,
|
||||
http_timeout=120,
|
||||
websocket_timeout=None,
|
||||
websocket_timeout=86400,
|
||||
websocket_connect_timeout=20,
|
||||
ping_interval=20,
|
||||
ping_timeout=30,
|
||||
|
@ -39,10 +43,9 @@ class Server(object):
|
|||
verbosity=1,
|
||||
websocket_handshake_timeout=5
|
||||
):
|
||||
self.channel_layer = channel_layer
|
||||
self.consumer = consumer
|
||||
self.application = application
|
||||
self.endpoints = endpoints or []
|
||||
if len(self.endpoints) == 0:
|
||||
if not self.endpoints:
|
||||
raise UserWarning("No endpoints. This server will not listen on anything.")
|
||||
self.listeners = []
|
||||
self.signal_handlers = signal_handlers
|
||||
|
@ -52,30 +55,20 @@ class Server(object):
|
|||
self.ping_timeout = ping_timeout
|
||||
self.proxy_forwarded_address_header = proxy_forwarded_address_header
|
||||
self.proxy_forwarded_port_header = proxy_forwarded_port_header
|
||||
# If they did not provide a websocket timeout, default it to the
|
||||
# channel layer's group_expiry value if present, or one day if not.
|
||||
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
|
||||
self.websocket_timeout = websocket_timeout
|
||||
self.websocket_connect_timeout = websocket_connect_timeout
|
||||
self.websocket_handshake_timeout = websocket_handshake_timeout
|
||||
self.ws_protocols = ws_protocols
|
||||
self.websocket_protocols = ws_protocols
|
||||
self.root_path = root_path
|
||||
self.verbosity = verbosity
|
||||
|
||||
def run(self):
|
||||
# Make the thread pool to run consumers in
|
||||
# TODO: Configurable numbers of threads
|
||||
self.pool = ThreadPool(name="consumers")
|
||||
# Make the mapping of consumer instances to consumer channels
|
||||
self.consumer_instances = {}
|
||||
# A set of current Twisted protocol instances to manage
|
||||
self.protocols = set()
|
||||
# Create process-local channel prefixes
|
||||
# TODO: Can we guarantee non-collision better?
|
||||
process_id = "".join(random.choice(string.ascii_letters) for i in range(10))
|
||||
self.consumer_channel_prefix = "daphne.%s!" % process_id
|
||||
self.application_instances = {}
|
||||
# Make the factory
|
||||
self.http_factory = HTTPFactory(self)
|
||||
self.ws_factory = WebSocketFactory(self, protocols=self.ws_protocols, server='Daphne')
|
||||
self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server='Daphne')
|
||||
self.ws_factory.setProtocolOptions(
|
||||
autoPingTimeout=self.ping_timeout,
|
||||
allowNullOrigin=True,
|
||||
|
@ -93,18 +86,24 @@ class Server(object):
|
|||
else:
|
||||
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
|
||||
|
||||
# Kick off the various background loops
|
||||
reactor.callLater(0, self.backend_reader)
|
||||
# Kick off the timeout loop
|
||||
reactor.callLater(1, self.application_checker)
|
||||
reactor.callLater(2, self.timeout_checker)
|
||||
|
||||
for socket_description in self.endpoints:
|
||||
logger.info("Listening on endpoint %s" % socket_description)
|
||||
# Twisted requires str on python2 (not unicode) and str on python3 (not bytes)
|
||||
ep = serverFromString(reactor, str(socket_description))
|
||||
self.listeners.append(ep.listen(self.http_factory))
|
||||
|
||||
self.pool.start()
|
||||
reactor.addSystemEventTrigger("before", "shutdown", self.pool.stop)
|
||||
# Set the asyncio reactor's event loop as global
|
||||
# TODO: Should we instead pass the global one into the reactor?
|
||||
asyncio.set_event_loop(reactor._asyncioEventloop)
|
||||
|
||||
# Verbosity 3 turns on asyncio debug to find those blocking yields
|
||||
if self.verbosity >= 3:
|
||||
asyncio.get_event_loop().set_debug(True)
|
||||
|
||||
reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
|
||||
reactor.run(installSignalHandlers=self.signal_handlers)
|
||||
|
||||
### Protocol handling
|
||||
|
@ -115,83 +114,86 @@ class Server(object):
|
|||
self.protocols.add(protocol)
|
||||
|
||||
def discard_protocol(self, protocol):
|
||||
# Ensure it's not in the protocol-tracking set
|
||||
self.protocols.discard(protocol)
|
||||
# Make sure any application future that's running is cancelled
|
||||
if protocol in self.application_instances:
|
||||
self.application_instances[protocol].cancel()
|
||||
del self.application_instances[protocol]
|
||||
|
||||
### Internal event/message handling
|
||||
|
||||
def create_consumer(self, protocol):
|
||||
def create_application(self, protocol):
|
||||
"""
|
||||
Creates a new consumer instance that fronts a Protocol instance
|
||||
Creates a new application instance that fronts a Protocol instance
|
||||
for one of our supported protocols. Pass it the protocol,
|
||||
and it will work out the type, supply appropriate callables, and
|
||||
put it into the server's consumer pool.
|
||||
|
||||
It returns the consumer channel name, which is how you should refer
|
||||
to the consumer instance.
|
||||
return you the application's input queue
|
||||
"""
|
||||
# Make sure the protocol defines a consumer type
|
||||
assert protocol.consumer_type is not None
|
||||
# Make it a consumer channel name
|
||||
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
|
||||
consumer_channel = self.consumer_channel_prefix + protocol_id
|
||||
# Make an instance of the consumer
|
||||
consumer_instance = self.consumer(
|
||||
type=protocol.consumer_type,
|
||||
# Make sure the protocol defines a application type
|
||||
assert protocol.application_type is not None
|
||||
# Make sure the protocol has not had another application made for it
|
||||
assert protocol not in self.application_instances
|
||||
# Make an instance of the application
|
||||
input_queue = asyncio.Queue()
|
||||
application_instance = asyncio.ensure_future(self.application(
|
||||
type=protocol.application_type,
|
||||
next=input_queue.get,
|
||||
reply=lambda message: self.handle_reply(protocol, message),
|
||||
channel_layer=self.channel_layer,
|
||||
consumer_channel=consumer_channel,
|
||||
)
|
||||
# Assign it by channel and return it
|
||||
self.consumer_instances[consumer_channel] = consumer_instance
|
||||
return consumer_channel
|
||||
), loop=asyncio.get_event_loop())
|
||||
self.application_instances[protocol] = application_instance
|
||||
return input_queue
|
||||
|
||||
def handle_message(self, consumer_channel, message):
|
||||
async def handle_reply(self, protocol, message):
|
||||
"""
|
||||
Schedules the application instance to handle the given message.
|
||||
Coroutine that jumps the reply message from asyncio to Twisted
|
||||
"""
|
||||
self.pool.callInThread(self.consumer_instances[consumer_channel], message)
|
||||
|
||||
def handle_reply(self, protocol, message):
|
||||
"""
|
||||
Schedules the reply to be handled by the protocol in the main thread
|
||||
"""
|
||||
reactor.callFromThread(reactor.callLater, 0, protocol.handle_reply, message)
|
||||
|
||||
### External event/message handling
|
||||
|
||||
def backend_reader(self):
|
||||
"""
|
||||
Runs as an-often-as-possible task with the reactor, unless there was
|
||||
no result previously in which case we add a small delay.
|
||||
"""
|
||||
channels = [self.consumer_channel_prefix]
|
||||
delay = 0
|
||||
# Quit if reactor is stopping
|
||||
if not reactor.running:
|
||||
logger.debug("Backend reader quitting due to reactor stop")
|
||||
return
|
||||
# Try to receive a message
|
||||
try:
|
||||
channel, message = self.channel_layer.receive(channels, block=False)
|
||||
except Exception as e:
|
||||
# Log the error and wait a bit to retry
|
||||
logger.error('Error trying to receive messages: %s' % e)
|
||||
delay = 5.00
|
||||
else:
|
||||
if channel:
|
||||
# Deal with the message
|
||||
try:
|
||||
self.handle_message(channel, message)
|
||||
except Exception as e:
|
||||
logger.error("Error handling external message: %s" % e)
|
||||
else:
|
||||
# If there's no messages, idle a little bit.
|
||||
delay = 0.05
|
||||
# We can't loop inside here as this is synchronous code.
|
||||
reactor.callLater(delay, self.backend_reader)
|
||||
reactor.callLater(0, protocol.handle_reply, message)
|
||||
|
||||
### Utility
|
||||
|
||||
def application_checker(self):
|
||||
"""
|
||||
Goes through the set of current application Futures and cleans up
|
||||
any that are done/prints exceptions for any that errored.
|
||||
"""
|
||||
for protocol, application_instance in list(self.application_instances.items()):
|
||||
if application_instance.done():
|
||||
exception = application_instance.exception()
|
||||
if exception:
|
||||
logging.error(
|
||||
"Exception inside application: {}\n{}{}".format(
|
||||
exception,
|
||||
"".join(traceback.format_tb(
|
||||
exception.__traceback__,
|
||||
)),
|
||||
" {}".format(exception),
|
||||
)
|
||||
)
|
||||
protocol.handle_exception(exception)
|
||||
try:
|
||||
del self.application_instances[protocol]
|
||||
except KeyError:
|
||||
# The protocol might have already got here before us. That's fine.
|
||||
pass
|
||||
reactor.callLater(1, self.application_checker)
|
||||
|
||||
def kill_all_applications(self):
|
||||
"""
|
||||
Kills all application coroutines before reactor exit.
|
||||
"""
|
||||
# Send cancel to all coroutines
|
||||
wait_for = []
|
||||
for application_instance in self.application_instances.values():
|
||||
if not application_instance.done():
|
||||
application_instance.cancel()
|
||||
wait_for.append(application_instance)
|
||||
logging.info("Killed %i pending application instances" % len(wait_for))
|
||||
# Make Twisted wait until they're all dead
|
||||
wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for))
|
||||
wait_deferred.addErrback(lambda x: None)
|
||||
return wait_deferred
|
||||
|
||||
def timeout_checker(self):
|
||||
"""
|
||||
Called periodically to enforce timeout rules on all connections.
|
||||
|
|
|
@ -20,20 +20,21 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
the websocket channels.
|
||||
"""
|
||||
|
||||
application_type = "websocket"
|
||||
|
||||
# If we should send no more messages (e.g. we error-closed the socket)
|
||||
muted = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(WebSocketProtocol, self).__init__(*args, **kwargs)
|
||||
self.server = self.factory.server_class
|
||||
|
||||
def onConnect(self, request):
|
||||
self.server = self.factory.server_class
|
||||
self.server.add_protocol(self)
|
||||
self.request = request
|
||||
self.packets_received = 0
|
||||
self.protocol_to_accept = None
|
||||
self.socket_opened = time.time()
|
||||
self.last_data = time.time()
|
||||
try:
|
||||
# Make new application instance
|
||||
self.application_queue = self.server.create_application(self)
|
||||
# Sanitize and decode headers
|
||||
self.clean_headers = []
|
||||
for name, value in request.headers.items():
|
||||
|
@ -42,10 +43,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
if b"_" in name:
|
||||
continue
|
||||
self.clean_headers.append((name.lower(), value.encode("latin1")))
|
||||
# Make sending channel
|
||||
self.reply_channel = self.main_factory.make_send_channel()
|
||||
# Tell main factory about it
|
||||
self.main_factory.reply_protocols[self.reply_channel] = self
|
||||
# Get client address if possible
|
||||
peer = self.transport.getPeer()
|
||||
host = self.transport.getHost()
|
||||
|
@ -56,23 +53,23 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.client_addr = None
|
||||
self.server_addr = None
|
||||
|
||||
if self.main_factory.proxy_forwarded_address_header:
|
||||
if self.server.proxy_forwarded_address_header:
|
||||
self.client_addr = parse_x_forwarded_for(
|
||||
self.http_headers,
|
||||
self.main_factory.proxy_forwarded_address_header,
|
||||
self.main_factory.proxy_forwarded_port_header,
|
||||
self.server.proxy_forwarded_address_header,
|
||||
self.server.proxy_forwarded_port_header,
|
||||
self.client_addr
|
||||
)
|
||||
|
||||
# Make initial request info dict from request (we only have it here)
|
||||
self.path = request.path.encode("ascii")
|
||||
self.request_info = {
|
||||
self.connect_message = {
|
||||
"type": "websocket.connect",
|
||||
"path": self.unquote(self.path),
|
||||
"headers": self.clean_headers,
|
||||
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
"reply_channel": self.reply_channel,
|
||||
"order": 0,
|
||||
}
|
||||
except:
|
||||
|
@ -86,28 +83,19 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
if header == b'sec-websocket-protocol':
|
||||
protocols = [x.strip() for x in self.unquote(value).split(",")]
|
||||
for protocol in protocols:
|
||||
if protocol in self.factory.protocols:
|
||||
if protocol in self.server.websocket_protocols:
|
||||
ws_protocol = protocol
|
||||
break
|
||||
|
||||
# Work out what subprotocol we will accept, if any
|
||||
if ws_protocol and ws_protocol in self.factory.protocols:
|
||||
if ws_protocol and ws_protocol in self.server.websocket_protocols:
|
||||
self.protocol_to_accept = ws_protocol
|
||||
else:
|
||||
self.protocol_to_accept = None
|
||||
|
||||
# Send over the connect message
|
||||
try:
|
||||
self.channel_layer.send("websocket.connect", self.request_info)
|
||||
except self.channel_layer.ChannelFull:
|
||||
# You have to consume websocket.connect according to the spec,
|
||||
# so drop the connection.
|
||||
self.muted = True
|
||||
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
|
||||
# Send code 503 "Service Unavailable" with close.
|
||||
raise ConnectionDeny(code=503, reason="Connection queue at capacity")
|
||||
else:
|
||||
self.factory.log_action("websocket", "connecting", {
|
||||
self.application_queue.put_nowait(self.connect_message)
|
||||
self.server.log_action("websocket", "connecting", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
|
@ -116,20 +104,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.handshake_deferred = defer.Deferred()
|
||||
return self.handshake_deferred
|
||||
|
||||
@classmethod
|
||||
def unquote(cls, value):
|
||||
"""
|
||||
Python 2 and 3 compat layer for utf-8 unquoting
|
||||
"""
|
||||
if six.PY2:
|
||||
return unquote(value).decode("utf8")
|
||||
else:
|
||||
return unquote(value.decode("ascii"))
|
||||
### Twisted event handling
|
||||
|
||||
def onOpen(self):
|
||||
# Send news that this channel is open
|
||||
logger.debug("WebSocket %s open and established", self.reply_channel)
|
||||
self.factory.log_action("websocket", "connected", {
|
||||
logger.debug("WebSocket %s open and established", self.client_addr)
|
||||
self.server.log_action("websocket", "connected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
|
@ -137,49 +117,78 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
def onMessage(self, payload, isBinary):
|
||||
# If we're muted, do nothing.
|
||||
if self.muted:
|
||||
logger.debug("Muting incoming frame on %s", self.reply_channel)
|
||||
logger.debug("Muting incoming frame on %s", self.client_addr)
|
||||
return
|
||||
logger.debug("WebSocket incoming frame on %s", self.reply_channel)
|
||||
self.packets_received += 1
|
||||
logger.debug("WebSocket incoming frame on %s", self.client_addr)
|
||||
self.last_data = time.time()
|
||||
try:
|
||||
if isBinary:
|
||||
self.channel_layer.send("websocket.receive", {
|
||||
"reply_channel": self.reply_channel,
|
||||
"path": self.unquote(self.path),
|
||||
"order": self.packets_received,
|
||||
self.application_queue.put_nowait({
|
||||
"type": "websocket.receive",
|
||||
"bytes": payload,
|
||||
})
|
||||
else:
|
||||
self.channel_layer.send("websocket.receive", {
|
||||
"reply_channel": self.reply_channel,
|
||||
"path": self.unquote(self.path),
|
||||
"order": self.packets_received,
|
||||
self.application_queue.put_nowait({
|
||||
"type": "websocket.receive",
|
||||
"text": payload.decode("utf8"),
|
||||
})
|
||||
except self.channel_layer.ChannelFull:
|
||||
# You have to consume websocket.receive according to the spec,
|
||||
# so drop the connection.
|
||||
self.muted = True
|
||||
logger.warn("WebSocket force closed for %s due to receive backpressure", self.reply_channel)
|
||||
# Send code 1013 "try again later" with close.
|
||||
self.sendCloseFrame(code=1013, isReply=False)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
"""
|
||||
Called when Twisted closes the socket.
|
||||
"""
|
||||
self.server.discard_protocol(self)
|
||||
logger.debug("WebSocket closed for %s", self.client_addr)
|
||||
if not self.muted:
|
||||
self.application_queue.put_nowait({
|
||||
"type": "websocket.disconnect",
|
||||
"code": code,
|
||||
})
|
||||
self.server.log_action("websocket", "disconnected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
|
||||
### Internal event handling
|
||||
|
||||
def handle_reply(self, message):
|
||||
if "type" not in message:
|
||||
raise ValueError("Message has no type defined")
|
||||
if message["type"] == "websocket.accept":
|
||||
self.serverAccept()
|
||||
elif message["type"] == "websocket.close":
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
self.serverReject()
|
||||
else:
|
||||
self.serverClose(code=message.get("code", None))
|
||||
elif message["type"] == "websocket.send":
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
raise ValueError("Socket has not been accepted, so cannot send over it")
|
||||
if message.get("bytes", None) and message.get("text", None):
|
||||
raise ValueError(
|
||||
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
|
||||
channel,
|
||||
)
|
||||
)
|
||||
if message.get("bytes", None):
|
||||
self.serverSend(message["bytes"], True)
|
||||
if message.get("text", None):
|
||||
self.serverSend(message["text"], False)
|
||||
|
||||
def serverAccept(self):
|
||||
"""
|
||||
Called when we get a message saying to accept the connection.
|
||||
"""
|
||||
self.handshake_deferred.callback(self.protocol_to_accept)
|
||||
logger.debug("WebSocket %s accepted by application", self.reply_channel)
|
||||
logger.debug("WebSocket %s accepted by application", self.client_addr)
|
||||
|
||||
def serverReject(self):
|
||||
"""
|
||||
Called when we get a message saying to reject the connection.
|
||||
"""
|
||||
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
||||
self.cleanup()
|
||||
logger.debug("WebSocket %s rejected by application", self.reply_channel)
|
||||
self.factory.log_action("websocket", "rejected", {
|
||||
self.server.discard_protocol(self)
|
||||
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
||||
self.server.log_action("websocket", "rejected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
|
@ -191,47 +200,30 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
if self.state == self.STATE_CONNECTING:
|
||||
self.serverAccept()
|
||||
self.last_data = time.time()
|
||||
logger.debug("Sent WebSocket packet to client for %s", self.reply_channel)
|
||||
logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
|
||||
if binary:
|
||||
self.sendMessage(content, binary)
|
||||
else:
|
||||
self.sendMessage(content.encode("utf8"), binary)
|
||||
|
||||
def serverClose(self, code=True):
|
||||
def serverClose(self, code=None):
|
||||
"""
|
||||
Server-side channel message to close the socket
|
||||
"""
|
||||
code = 1000 if code is True else code
|
||||
code = 1000 if code is None else code
|
||||
self.sendClose(code=code)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
self.cleanup()
|
||||
if hasattr(self, "reply_channel"):
|
||||
logger.debug("WebSocket closed for %s", self.reply_channel)
|
||||
try:
|
||||
if not self.muted:
|
||||
self.channel_layer.send("websocket.disconnect", {
|
||||
"reply_channel": self.reply_channel,
|
||||
"code": code,
|
||||
"path": self.unquote(self.path),
|
||||
"order": self.packets_received + 1,
|
||||
})
|
||||
except self.channel_layer.ChannelFull:
|
||||
pass
|
||||
self.factory.log_action("websocket", "disconnected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
else:
|
||||
logger.debug("WebSocket closed before handshake established")
|
||||
### Utils
|
||||
|
||||
def cleanup(self):
|
||||
@classmethod
|
||||
def unquote(cls, value):
|
||||
"""
|
||||
Call to clean up this socket after it's closed.
|
||||
Python 2 and 3 compat layer for utf-8 unquoting
|
||||
"""
|
||||
if hasattr(self, "reply_channel"):
|
||||
if self.reply_channel in self.factory.reply_protocols:
|
||||
del self.factory.reply_protocols[self.reply_channel]
|
||||
if six.PY2:
|
||||
return unquote(value).decode("utf8")
|
||||
else:
|
||||
return unquote(value.decode("ascii"))
|
||||
|
||||
def duration(self):
|
||||
"""
|
||||
|
@ -275,3 +267,15 @@ class WebSocketFactory(WebSocketServerFactory):
|
|||
def __init__(self, server_class, *args, **kwargs):
|
||||
self.server_class = server_class
|
||||
WebSocketServerFactory.__init__(self, *args, **kwargs)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
"""
|
||||
Builds protocol instances. We use this to inject the factory object into the protocol.
|
||||
"""
|
||||
try:
|
||||
protocol = super(WebSocketFactory, self).buildProtocol(addr)
|
||||
protocol.factory = self
|
||||
return protocol
|
||||
except Exception as e:
|
||||
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||
raise
|
||||
|
|
Loading…
Reference in New Issue
Block a user