Trying out asyncio based interface

This commit is contained in:
Andrew Godwin 2017-09-07 21:24:14 -07:00
parent a656c9f4c6
commit 01f174bf26
5 changed files with 224 additions and 292 deletions

View File

@ -131,12 +131,8 @@ class CommandLineInterface(object):
default=2,
)
self.parser.add_argument(
'channel_layer',
help='The ASGI channel layer instance to use as path.to.module:instance.path',
)
self.parser.add_argument(
'consumer',
help='The consumer to dispatch to as path.to.module:instance.path',
'application',
help='The application to dispatch to as path.to.module:instance.path',
)
self.server = None
@ -161,6 +157,7 @@ class CommandLineInterface(object):
0: logging.WARN,
1: logging.INFO,
2: logging.DEBUG,
3: logging.DEBUG, # Also turns on asyncio debug
}[args.verbosity],
format="%(asctime)-15s %(levelname)-8s %(message)s",
)
@ -173,10 +170,9 @@ class CommandLineInterface(object):
access_log_stream = open(args.access_log, "a", 1)
elif args.verbosity >= 1:
access_log_stream = sys.stdout
# Import application and channel_layer
# Import application
sys.path.insert(0, ".")
channel_layer = import_by_path(args.channel_layer)
consumer = import_by_path(args.consumer)
application = import_by_path(args.application)
# Set up port/host bindings
if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
# no advanced binding options passed, patch in defaults
@ -198,12 +194,11 @@ class CommandLineInterface(object):
)
# Start the server
logger.info(
'Starting server at %s, channel layer %s.' %
(', '.join(endpoints), args.channel_layer)
'Starting server at %s' %
(', '.join(endpoints), )
)
self.server = Server(
channel_layer=channel_layer,
consumer=consumer,
application=application,
endpoints=endpoints,
http_timeout=args.http_timeout,
ping_interval=args.ping_interval,

View File

@ -10,6 +10,7 @@ import traceback
from zope.interface import implementer
from six.moves.urllib_parse import unquote, unquote_plus
from twisted.internet.defer import ensureDeferred
from twisted.internet.interfaces import IProtocolNegotiationFactory
from twisted.protocols.policies import ProtocolWrapper
from twisted.web import http
@ -29,7 +30,7 @@ class WebRequest(http.Request):
GET and POST out.
"""
consumer_type = "http"
application_type = "http"
error_template = """
<html>
@ -55,7 +56,7 @@ class WebRequest(http.Request):
http.Request.__init__(self, *args, **kwargs)
# Easy server link
self.server = self.channel.factory.server
self.consumer_channel = None
self.application_queue = None
self._response_started = False
self.server.add_protocol(self)
except Exception:
@ -137,8 +138,8 @@ class WebRequest(http.Request):
# Boring old HTTP.
else:
# Create consumer to handle this connection
self.consumer_channel = self.server.create_consumer(self)
# Create application to handle this connection
self.application_queue = self.server.create_application(self)
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
self.root_path = self.server.root_path
@ -151,11 +152,10 @@ class WebRequest(http.Request):
self.root_path = self.unquote(value)
else:
self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.consumer_channel)
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
self.content.seek(0, 0)
# Run consumer
self.server.handle_message(
self.consumer_channel,
# Run application against request
self.application_queue.put_nowait(
{
"type": "http.request",
# TODO: Correctly say if it's 1.1 or 1.0
@ -173,13 +173,13 @@ class WebRequest(http.Request):
)
except Exception:
logger.error(traceback.format_exc())
self.basic_error(500, b"Internal Server Error", "HTTP processing error")
self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error")
def connectionLost(self, reason):
"""
Cleans up reply channel on close.
"""
if self.consumer_channel:
if self.application_queue:
self.send_disconnect()
logger.debug("HTTP disconnect for %s", self.client_addr)
http.Request.connectionLost(self, reason)
@ -189,7 +189,7 @@ class WebRequest(http.Request):
"""
Cleans up reply channel on close.
"""
if self.consumer_channel:
if self.application_queue:
self.send_disconnect()
logger.debug("HTTP close for %s", self.client_addr)
http.Request.finish(self)
@ -201,6 +201,8 @@ class WebRequest(http.Request):
"""
Handles a reply from the client
"""
if "type" not in message:
raise ValueError("Message has no type defined")
if message['type'] == "http.response":
if self._response_started:
raise ValueError("HTTP response has already been started")
@ -216,9 +218,9 @@ class WebRequest(http.Request):
header = header.encode("latin1")
self.responseHeaders.addRawHeader(header, value)
logger.debug("HTTP %s response started for %s", message['status'], self.client_addr)
elif message['type'] == "http.response.content":
elif message['type'] == "http.response.hunk":
if not self._response_started:
raise ValueError("HTTP response has not yet been started")
raise ValueError("HTTP response has not yet been started but got %s" % message['type'])
else:
raise ValueError("Cannot handle message type %s!" % message['type'])
@ -243,6 +245,12 @@ class WebRequest(http.Request):
else:
logger.debug("HTTP response chunk for %s", self.client_addr)
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
self.basic_error(500, b"Internal Server Error", "Exception inside application.")
def check_timeouts(self):
"""
Called periodically to see if we should timeout something
@ -276,8 +284,7 @@ class WebRequest(http.Request):
"""
# If we don't yet have a path, then don't send as we never opened.
if self.path:
self.server.handle_message(
self.consumer_channel,
self.application_queue.put_nowait(
{
"type": "http.disconnect",
},
@ -295,7 +302,8 @@ class WebRequest(http.Request):
"""
Responds with a server-level error page (very basic)
"""
self.serverResponse({
self.handle_reply({
"type": "http.response",
"status": status,
"status_text": status_text,
"headers": [
@ -340,83 +348,6 @@ class HTTPFactory(http.HTTPFactory):
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise
def make_send_channel(self):
"""
Makes a new send channel for a protocol with our process prefix.
"""
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
return self.send_channel + protocol_id
def dispatch_reply(self, channel, message):
# If we don't know about the channel, ignore it (likely a channel we
# used to have that's now in a group).
# TODO: Find a better way of alerting people when this happens so
# they can do more cleanup, that's not an error.
if channel not in self.reply_protocols:
logger.debug("Message on unknown channel %r - discarding" % channel)
return
if isinstance(self.reply_protocols[channel], WebRequest):
self.reply_protocols[channel].serverResponse(message)
elif isinstance(self.reply_protocols[channel], WebSocketProtocol):
# Switch depending on current socket state
protocol = self.reply_protocols[channel]
# See if the message is valid
unknown_keys = set(message.keys()) - {"bytes", "text", "close", "accept"}
if unknown_keys:
raise ValueError(
"Got invalid WebSocket reply message on %s - "
"contains unknown keys %s (looking for either {'accept', 'text', 'bytes', 'close'})" % (
channel,
unknown_keys,
)
)
# Accepts allow bytes/text afterwards
if message.get("accept", None) and protocol.state == protocol.STATE_CONNECTING:
protocol.serverAccept()
# Rejections must be the only thing
if message.get("accept", None) == False and protocol.state == protocol.STATE_CONNECTING:
protocol.serverReject()
return
# You're only allowed one of bytes or text
if message.get("bytes", None) and message.get("text", None):
raise ValueError(
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
channel,
)
)
if message.get("bytes", None):
protocol.serverSend(message["bytes"], True)
if message.get("text", None):
protocol.serverSend(message["text"], False)
closing_code = message.get("close", False)
if closing_code:
if protocol.state == protocol.STATE_CONNECTING:
protocol.serverReject()
else:
protocol.serverClose(code=closing_code)
else:
raise ValueError("Unknown protocol class")
def check_timeouts(self):
"""
Runs through all HTTP protocol instances and times them out if they've
taken too long (and so their message is probably expired)
"""
for protocol in list(self.reply_protocols.values()):
# Web timeout checking
if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout:
protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.")
# WebSocket timeout checking and keepalive ping sending
elif isinstance(protocol, WebSocketProtocol):
# Timeout check
if protocol.duration() > self.websocket_timeout and self.websocket_timeout >= 0:
protocol.serverClose()
# Ping check
else:
protocol.check_ping()
# IProtocolNegotiationFactory
def acceptableProtocols(self):
"""

View File

@ -1,16 +1,21 @@
from __future__ import unicode_literals
# This has to be done first as Twisted is import-order-sensitive with reactors
from twisted.internet import asyncioreactor
asyncioreactor.install()
import asyncio
import collections
import logging
import random
import string
import traceback
import warnings
from twisted.internet import reactor, defer
from twisted.internet.endpoints import serverFromString
from twisted.logger import globalLogBeginner, STDLibLogObserver
from twisted.web import http
from twisted.python.threadpool import ThreadPool
from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory
@ -22,13 +27,12 @@ class Server(object):
def __init__(
self,
channel_layer,
consumer,
application,
endpoints=None,
signal_handlers=True,
action_logger=None,
http_timeout=120,
websocket_timeout=None,
websocket_timeout=86400,
websocket_connect_timeout=20,
ping_interval=20,
ping_timeout=30,
@ -39,10 +43,9 @@ class Server(object):
verbosity=1,
websocket_handshake_timeout=5
):
self.channel_layer = channel_layer
self.consumer = consumer
self.application = application
self.endpoints = endpoints or []
if len(self.endpoints) == 0:
if not self.endpoints:
raise UserWarning("No endpoints. This server will not listen on anything.")
self.listeners = []
self.signal_handlers = signal_handlers
@ -52,30 +55,20 @@ class Server(object):
self.ping_timeout = ping_timeout
self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header
# If they did not provide a websocket timeout, default it to the
# channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout
self.websocket_handshake_timeout = websocket_handshake_timeout
self.ws_protocols = ws_protocols
self.websocket_protocols = ws_protocols
self.root_path = root_path
self.verbosity = verbosity
def run(self):
# Make the thread pool to run consumers in
# TODO: Configurable numbers of threads
self.pool = ThreadPool(name="consumers")
# Make the mapping of consumer instances to consumer channels
self.consumer_instances = {}
# A set of current Twisted protocol instances to manage
self.protocols = set()
# Create process-local channel prefixes
# TODO: Can we guarantee non-collision better?
process_id = "".join(random.choice(string.ascii_letters) for i in range(10))
self.consumer_channel_prefix = "daphne.%s!" % process_id
self.application_instances = {}
# Make the factory
self.http_factory = HTTPFactory(self)
self.ws_factory = WebSocketFactory(self, protocols=self.ws_protocols, server='Daphne')
self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server='Daphne')
self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout,
allowNullOrigin=True,
@ -93,18 +86,24 @@ class Server(object):
else:
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
# Kick off the various background loops
reactor.callLater(0, self.backend_reader)
# Kick off the timeout loop
reactor.callLater(1, self.application_checker)
reactor.callLater(2, self.timeout_checker)
for socket_description in self.endpoints:
logger.info("Listening on endpoint %s" % socket_description)
# Twisted requires str on python2 (not unicode) and str on python3 (not bytes)
ep = serverFromString(reactor, str(socket_description))
self.listeners.append(ep.listen(self.http_factory))
self.pool.start()
reactor.addSystemEventTrigger("before", "shutdown", self.pool.stop)
# Set the asyncio reactor's event loop as global
# TODO: Should we instead pass the global one into the reactor?
asyncio.set_event_loop(reactor._asyncioEventloop)
# Verbosity 3 turns on asyncio debug to find those blocking yields
if self.verbosity >= 3:
asyncio.get_event_loop().set_debug(True)
reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
reactor.run(installSignalHandlers=self.signal_handlers)
### Protocol handling
@ -115,83 +114,86 @@ class Server(object):
self.protocols.add(protocol)
def discard_protocol(self, protocol):
# Ensure it's not in the protocol-tracking set
self.protocols.discard(protocol)
# Make sure any application future that's running is cancelled
if protocol in self.application_instances:
self.application_instances[protocol].cancel()
del self.application_instances[protocol]
### Internal event/message handling
def create_consumer(self, protocol):
def create_application(self, protocol):
"""
Creates a new consumer instance that fronts a Protocol instance
Creates a new application instance that fronts a Protocol instance
for one of our supported protocols. Pass it the protocol,
and it will work out the type, supply appropriate callables, and
put it into the server's consumer pool.
It returns the consumer channel name, which is how you should refer
to the consumer instance.
return you the application's input queue
"""
# Make sure the protocol defines a consumer type
assert protocol.consumer_type is not None
# Make it a consumer channel name
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
consumer_channel = self.consumer_channel_prefix + protocol_id
# Make an instance of the consumer
consumer_instance = self.consumer(
type=protocol.consumer_type,
# Make sure the protocol defines a application type
assert protocol.application_type is not None
# Make sure the protocol has not had another application made for it
assert protocol not in self.application_instances
# Make an instance of the application
input_queue = asyncio.Queue()
application_instance = asyncio.ensure_future(self.application(
type=protocol.application_type,
next=input_queue.get,
reply=lambda message: self.handle_reply(protocol, message),
channel_layer=self.channel_layer,
consumer_channel=consumer_channel,
)
# Assign it by channel and return it
self.consumer_instances[consumer_channel] = consumer_instance
return consumer_channel
), loop=asyncio.get_event_loop())
self.application_instances[protocol] = application_instance
return input_queue
def handle_message(self, consumer_channel, message):
async def handle_reply(self, protocol, message):
"""
Schedules the application instance to handle the given message.
Coroutine that jumps the reply message from asyncio to Twisted
"""
self.pool.callInThread(self.consumer_instances[consumer_channel], message)
def handle_reply(self, protocol, message):
"""
Schedules the reply to be handled by the protocol in the main thread
"""
reactor.callFromThread(reactor.callLater, 0, protocol.handle_reply, message)
### External event/message handling
def backend_reader(self):
"""
Runs as an-often-as-possible task with the reactor, unless there was
no result previously in which case we add a small delay.
"""
channels = [self.consumer_channel_prefix]
delay = 0
# Quit if reactor is stopping
if not reactor.running:
logger.debug("Backend reader quitting due to reactor stop")
return
# Try to receive a message
try:
channel, message = self.channel_layer.receive(channels, block=False)
except Exception as e:
# Log the error and wait a bit to retry
logger.error('Error trying to receive messages: %s' % e)
delay = 5.00
else:
if channel:
# Deal with the message
try:
self.handle_message(channel, message)
except Exception as e:
logger.error("Error handling external message: %s" % e)
else:
# If there's no messages, idle a little bit.
delay = 0.05
# We can't loop inside here as this is synchronous code.
reactor.callLater(delay, self.backend_reader)
reactor.callLater(0, protocol.handle_reply, message)
### Utility
def application_checker(self):
"""
Goes through the set of current application Futures and cleans up
any that are done/prints exceptions for any that errored.
"""
for protocol, application_instance in list(self.application_instances.items()):
if application_instance.done():
exception = application_instance.exception()
if exception:
logging.error(
"Exception inside application: {}\n{}{}".format(
exception,
"".join(traceback.format_tb(
exception.__traceback__,
)),
" {}".format(exception),
)
)
protocol.handle_exception(exception)
try:
del self.application_instances[protocol]
except KeyError:
# The protocol might have already got here before us. That's fine.
pass
reactor.callLater(1, self.application_checker)
def kill_all_applications(self):
"""
Kills all application coroutines before reactor exit.
"""
# Send cancel to all coroutines
wait_for = []
for application_instance in self.application_instances.values():
if not application_instance.done():
application_instance.cancel()
wait_for.append(application_instance)
logging.info("Killed %i pending application instances" % len(wait_for))
# Make Twisted wait until they're all dead
wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for))
wait_deferred.addErrback(lambda x: None)
return wait_deferred
def timeout_checker(self):
"""
Called periodically to enforce timeout rules on all connections.

View File

@ -20,20 +20,21 @@ class WebSocketProtocol(WebSocketServerProtocol):
the websocket channels.
"""
application_type = "websocket"
# If we should send no more messages (e.g. we error-closed the socket)
muted = False
def __init__(self, *args, **kwargs):
super(WebSocketProtocol, self).__init__(*args, **kwargs)
self.server = self.factory.server_class
def onConnect(self, request):
self.server = self.factory.server_class
self.server.add_protocol(self)
self.request = request
self.packets_received = 0
self.protocol_to_accept = None
self.socket_opened = time.time()
self.last_data = time.time()
try:
# Make new application instance
self.application_queue = self.server.create_application(self)
# Sanitize and decode headers
self.clean_headers = []
for name, value in request.headers.items():
@ -42,10 +43,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
if b"_" in name:
continue
self.clean_headers.append((name.lower(), value.encode("latin1")))
# Make sending channel
self.reply_channel = self.main_factory.make_send_channel()
# Tell main factory about it
self.main_factory.reply_protocols[self.reply_channel] = self
# Get client address if possible
peer = self.transport.getPeer()
host = self.transport.getHost()
@ -56,23 +53,23 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.client_addr = None
self.server_addr = None
if self.main_factory.proxy_forwarded_address_header:
if self.server.proxy_forwarded_address_header:
self.client_addr = parse_x_forwarded_for(
self.http_headers,
self.main_factory.proxy_forwarded_address_header,
self.main_factory.proxy_forwarded_port_header,
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.client_addr
)
# Make initial request info dict from request (we only have it here)
self.path = request.path.encode("ascii")
self.request_info = {
self.connect_message = {
"type": "websocket.connect",
"path": self.unquote(self.path),
"headers": self.clean_headers,
"query_string": self._raw_query_string, # Passed by HTTP protocol
"client": self.client_addr,
"server": self.server_addr,
"reply_channel": self.reply_channel,
"order": 0,
}
except:
@ -86,50 +83,33 @@ class WebSocketProtocol(WebSocketServerProtocol):
if header == b'sec-websocket-protocol':
protocols = [x.strip() for x in self.unquote(value).split(",")]
for protocol in protocols:
if protocol in self.factory.protocols:
if protocol in self.server.websocket_protocols:
ws_protocol = protocol
break
# Work out what subprotocol we will accept, if any
if ws_protocol and ws_protocol in self.factory.protocols:
if ws_protocol and ws_protocol in self.server.websocket_protocols:
self.protocol_to_accept = ws_protocol
else:
self.protocol_to_accept = None
# Send over the connect message
try:
self.channel_layer.send("websocket.connect", self.request_info)
except self.channel_layer.ChannelFull:
# You have to consume websocket.connect according to the spec,
# so drop the connection.
self.muted = True
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
# Send code 503 "Service Unavailable" with close.
raise ConnectionDeny(code=503, reason="Connection queue at capacity")
else:
self.factory.log_action("websocket", "connecting", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
self.application_queue.put_nowait(self.connect_message)
self.server.log_action("websocket", "connecting", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
# Make a deferred and return it - we'll either call it or err it later on
self.handshake_deferred = defer.Deferred()
return self.handshake_deferred
@classmethod
def unquote(cls, value):
"""
Python 2 and 3 compat layer for utf-8 unquoting
"""
if six.PY2:
return unquote(value).decode("utf8")
else:
return unquote(value.decode("ascii"))
### Twisted event handling
def onOpen(self):
# Send news that this channel is open
logger.debug("WebSocket %s open and established", self.reply_channel)
self.factory.log_action("websocket", "connected", {
logger.debug("WebSocket %s open and established", self.client_addr)
self.server.log_action("websocket", "connected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
@ -137,49 +117,78 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onMessage(self, payload, isBinary):
# If we're muted, do nothing.
if self.muted:
logger.debug("Muting incoming frame on %s", self.reply_channel)
logger.debug("Muting incoming frame on %s", self.client_addr)
return
logger.debug("WebSocket incoming frame on %s", self.reply_channel)
self.packets_received += 1
logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_data = time.time()
try:
if isBinary:
self.channel_layer.send("websocket.receive", {
"reply_channel": self.reply_channel,
"path": self.unquote(self.path),
"order": self.packets_received,
"bytes": payload,
})
if isBinary:
self.application_queue.put_nowait({
"type": "websocket.receive",
"bytes": payload,
})
else:
self.application_queue.put_nowait({
"type": "websocket.receive",
"text": payload.decode("utf8"),
})
def onClose(self, wasClean, code, reason):
"""
Called when Twisted closes the socket.
"""
self.server.discard_protocol(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted:
self.application_queue.put_nowait({
"type": "websocket.disconnect",
"code": code,
})
self.server.log_action("websocket", "disconnected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
### Internal event handling
def handle_reply(self, message):
if "type" not in message:
raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept":
self.serverAccept()
elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING:
self.serverReject()
else:
self.channel_layer.send("websocket.receive", {
"reply_channel": self.reply_channel,
"path": self.unquote(self.path),
"order": self.packets_received,
"text": payload.decode("utf8"),
})
except self.channel_layer.ChannelFull:
# You have to consume websocket.receive according to the spec,
# so drop the connection.
self.muted = True
logger.warn("WebSocket force closed for %s due to receive backpressure", self.reply_channel)
# Send code 1013 "try again later" with close.
self.sendCloseFrame(code=1013, isReply=False)
self.serverClose(code=message.get("code", None))
elif message["type"] == "websocket.send":
if self.state == self.STATE_CONNECTING:
raise ValueError("Socket has not been accepted, so cannot send over it")
if message.get("bytes", None) and message.get("text", None):
raise ValueError(
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
channel,
)
)
if message.get("bytes", None):
self.serverSend(message["bytes"], True)
if message.get("text", None):
self.serverSend(message["text"], False)
def serverAccept(self):
"""
Called when we get a message saying to accept the connection.
"""
self.handshake_deferred.callback(self.protocol_to_accept)
logger.debug("WebSocket %s accepted by application", self.reply_channel)
logger.debug("WebSocket %s accepted by application", self.client_addr)
def serverReject(self):
"""
Called when we get a message saying to reject the connection.
"""
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
self.cleanup()
logger.debug("WebSocket %s rejected by application", self.reply_channel)
self.factory.log_action("websocket", "rejected", {
self.server.discard_protocol(self)
logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action("websocket", "rejected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
@ -191,47 +200,30 @@ class WebSocketProtocol(WebSocketServerProtocol):
if self.state == self.STATE_CONNECTING:
self.serverAccept()
self.last_data = time.time()
logger.debug("Sent WebSocket packet to client for %s", self.reply_channel)
logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
if binary:
self.sendMessage(content, binary)
else:
self.sendMessage(content.encode("utf8"), binary)
def serverClose(self, code=True):
def serverClose(self, code=None):
"""
Server-side channel message to close the socket
"""
code = 1000 if code is True else code
code = 1000 if code is None else code
self.sendClose(code=code)
def onClose(self, wasClean, code, reason):
self.cleanup()
if hasattr(self, "reply_channel"):
logger.debug("WebSocket closed for %s", self.reply_channel)
try:
if not self.muted:
self.channel_layer.send("websocket.disconnect", {
"reply_channel": self.reply_channel,
"code": code,
"path": self.unquote(self.path),
"order": self.packets_received + 1,
})
except self.channel_layer.ChannelFull:
pass
self.factory.log_action("websocket", "disconnected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
else:
logger.debug("WebSocket closed before handshake established")
### Utils
def cleanup(self):
@classmethod
def unquote(cls, value):
"""
Call to clean up this socket after it's closed.
Python 2 and 3 compat layer for utf-8 unquoting
"""
if hasattr(self, "reply_channel"):
if self.reply_channel in self.factory.reply_protocols:
del self.factory.reply_protocols[self.reply_channel]
if six.PY2:
return unquote(value).decode("utf8")
else:
return unquote(value.decode("ascii"))
def duration(self):
"""
@ -275,3 +267,15 @@ class WebSocketFactory(WebSocketServerFactory):
def __init__(self, server_class, *args, **kwargs):
self.server_class = server_class
WebSocketServerFactory.__init__(self, *args, **kwargs)
def buildProtocol(self, addr):
"""
Builds protocol instances. We use this to inject the factory object into the protocol.
"""
try:
protocol = super(WebSocketFactory, self).buildProtocol(addr)
protocol.factory = self
return protocol
except Exception as e:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise

View File

@ -24,7 +24,7 @@ setup(
include_package_data=True,
install_requires=[
'asgiref~=1.1',
'twisted>=17.1',
'twisted>=17.5',
'autobahn>=0.18',
],
extras_require={