From dcfaf4122b0fdd591469f74d207c701c2fce16d7 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Mon, 3 Oct 2016 16:38:37 -0700 Subject: [PATCH] Work in progress towards accepting websockets explicitly --- channels/channel.py | 3 + channels/consumer_middleware.py | 87 +++++++++++++++++++++++ channels/exceptions.py | 8 +++ channels/management/commands/runserver.py | 4 ++ channels/routing.py | 10 ++- channels/signals.py | 4 ++ channels/worker.py | 5 +- docs/asgi.rst | 37 +++++++++- docs/getting-started.rst | 8 ++- 9 files changed, 161 insertions(+), 5 deletions(-) create mode 100644 channels/consumer_middleware.py diff --git a/channels/channel.py b/channels/channel.py index b65d65a..4d2796b 100644 --- a/channels/channel.py +++ b/channels/channel.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from django.utils import six from channels import DEFAULT_CHANNEL_LAYER, channel_layers +from .signals import message_sent class Channel(object): @@ -36,6 +37,8 @@ class Channel(object): if not isinstance(content, dict): raise TypeError("You can only send dicts as content on channels.") self.channel_layer.send(self.name, content) + message_sent.send(sender=self.__class__, channel=self.name, keys=list(content.keys())) + print("didsig", self.name) def __str__(self): return self.name diff --git a/channels/consumer_middleware.py b/channels/consumer_middleware.py new file mode 100644 index 0000000..869863e --- /dev/null +++ b/channels/consumer_middleware.py @@ -0,0 +1,87 @@ +from __future__ import unicode_literals + +import importlib +import threading +from django.conf import settings + +from .exceptions import DenyConnection +from .signals import consumer_started, consumer_finished, message_sent + + +class ConsumerMiddlewareRegistry(object): + """ + Handles registration (via settings object) and generation of consumer + middleware stacks + """ + + fixed_middleware = ["channels.consumer_middleware.ConvenienceMiddleware"] + + def __init__(self): + # Load middleware callables from settings + middleware_paths = self.fixed_middleware + getattr(settings, "CONSUMER_MIDDLEWARE", []) + self.middleware_instances = [] + for path in middleware_paths: + module_name, variable_name = path.rsplit(".", 1) + try: + self.middleware_instances.append(getattr(importlib.import_module(module_name), variable_name)) + except (ImportError, AttributeError) as e: + raise ImproperlyConfigured("Cannot import consumer middleware %r: %s" % (path, e)) + + def make_chain(self, consumer, kwargs): + """ + Returns an instantiated chain of middleware around a final consumer. + """ + next_layer = lambda message: consumer(message, **kwargs) + for middleware_instance in reversed(self.middleware_instances): + next_layer = middleware_instance(next_layer) + return next_layer + + +class ConvenienceMiddleware(object): + """ + Standard middleware which papers over some more explicit parts of ASGI. + """ + + runtime_data = threading.local() + + def __init__(self, consumer): + self.consumer = consumer + + def __call__(self, message): + print("conven", message.channel) + if message.channel.name == "websocket.connect": + # Websocket connect acceptance helper + try: + self.consumer(message) + print ("messages sent", self.get_messages()) + except DenyConnection: + message.reply_channel.send({"accept": False}) + else: + # General path + return self.consumer(message) + + @classmethod + def reset_messages(cls, **kwargs): + """ + Tied to the consumer started/ended signal to reset the messages list. + """ + cls.runtime_data.sent_messages = [] + + consumer_started.connect(lambda **kwargs: reset_messages()) + consumer_finished.connect(lambda **kwargs: reset_messages()) + + @classmethod + def sent_message(cls, channel, keys, **kwargs): + """ + Called by message sending interfaces when messages are sent, + for convenience errors only. Should not be relied upon to get + all messages. + """ + cls.runtime_data.sent_messages = getattr(cls.runtime_data, "sent_messages", []) + [(channel, keys)] + print ("saved now", cls.runtime_data.sent_messages) + + message_sent.connect(lambda channel, keys, **kwargs: sent_message(channel, keys)) + + @classmethod + def get_messages(cls): + return getattr(cls.runtime_data, "sent_messages", []) diff --git a/channels/exceptions.py b/channels/exceptions.py index ffdb5a2..d81af8b 100644 --- a/channels/exceptions.py +++ b/channels/exceptions.py @@ -29,3 +29,11 @@ class RequestAborted(Exception): reading the body. """ pass + + +class DenyConnection(Exception): + """ + Raise during a websocket.connect (or other supported connection) handler + to deny the connection. + """ + pass diff --git a/channels/management/commands/runserver.py b/channels/management/commands/runserver.py index 3b68220..e52b018 100644 --- a/channels/management/commands/runserver.py +++ b/channels/management/commands/runserver.py @@ -121,6 +121,10 @@ class Command(RunserverCommand): msg += "WebSocket CONNECT %(path)s [%(client)s]\n" % details elif protocol == "websocket" and action == "disconnected": msg += "WebSocket DISCONNECT %(path)s [%(client)s]\n" % details + elif protocol == "websocket" and action == "connecting": + msg += "WebSocket HANDSHAKING %(path)s [%(client)s]\n" % details + elif protocol == "websocket" and action == "rejected": + msg += "WebSocket REJECT %(path)s [%(client)s]\n" % details sys.stderr.write(msg) diff --git a/channels/routing.py b/channels/routing.py index 76c2464..cb3ea03 100644 --- a/channels/routing.py +++ b/channels/routing.py @@ -56,8 +56,9 @@ class Router(object): # We also add a no-op websocket.connect consumer to the bottom, as the # spec requires that this is consumed, but Channels does not. Any user # consumer will override this one. Same for websocket.receive. - self.add_route(Route("websocket.connect", null_consumer)) + self.add_route(Route("websocket.connect", connect_consumer)) self.add_route(Route("websocket.receive", null_consumer)) + self.add_route(Route("websocket.disconnect", null_consumer)) @classmethod def resolve_routing(cls, routing): @@ -250,6 +251,13 @@ def null_consumer(*args, **kwargs): """ +def connect_consumer(message, *args, **kwargs): + """ + Accept-all-connections websocket.connect consumer + """ + message.reply_channel.send({"accept": True}) + + # Lowercase standard to match urls.py route = Route route_class = RouteClass diff --git a/channels/signals.py b/channels/signals.py index dc83b94..0a0e575 100644 --- a/channels/signals.py +++ b/channels/signals.py @@ -7,5 +7,9 @@ consumer_finished = Signal() worker_ready = Signal() worker_process_ready = Signal() +# Called when a message is sent directly to a channel. Not called for group +# sends or direct ASGI usage. For convenience/nicer errors only. +message_sent = Signal(providing_args=["channel", "keys"]) + # Connect connection closer to consumer finished as well consumer_finished.connect(close_old_connections) diff --git a/channels/worker.py b/channels/worker.py index 3b67e92..4f48615 100644 --- a/channels/worker.py +++ b/channels/worker.py @@ -13,6 +13,7 @@ from .exceptions import ConsumeLater from .message import Message from .utils import name_that_thing from .signals import worker_ready +from .consumer_middleware import ConsumerMiddlewareRegistry logger = logging.getLogger('django.channels') @@ -40,6 +41,7 @@ class Worker(object): self.exclude_channels = exclude_channels self.termed = False self.in_job = False + self.middleware_registry = ConsumerMiddlewareRegistry() def install_signal_handler(self): signal.signal(signal.SIGTERM, self.sigterm_handler) @@ -117,7 +119,8 @@ class Worker(object): # Send consumer started to manage lifecycle stuff consumer_started.send(sender=self.__class__, environ={}) # Run consumer - consumer(message, **kwargs) + chain = self.middleware_registry.make_chain(consumer, kwargs) + chain(message) except ConsumeLater: # They want to not handle it yet. Re-inject it with a number-of-tries marker. content['__retries__'] = content.get("__retries__", 0) + 1 diff --git a/docs/asgi.rst b/docs/asgi.rst index e452217..919d077 100644 --- a/docs/asgi.rst +++ b/docs/asgi.rst @@ -704,7 +704,7 @@ Keys: * ``reply_channel``: Channel name responses would have been sent on. No longer valid after this message is sent; all messages to it will be dropped. - + * ``path``: Unicode string HTTP path from URL, with percent escapes decoded and UTF8 byte sequences decoded into characters. @@ -731,7 +731,21 @@ Connection Sent when the client initially opens a connection and completes the WebSocket handshake. If sending this raises ``ChannelFull``, the interface -server must close the WebSocket connection with error code 1013. +server must close the connection with either HTTP status code ``503`` or +WebSocket close code ``1013``. + +This message must be responded to on the ``reply_channel`` with a +*Connection Reply* message before the socket will pass messages on the +``receive`` channel. The protocol server should ideally send this message +during the handshake phase of the WebSocket and not complete the handshake +until it gets a reply, returning HTTP status code ``403`` if the connection is +denied. If this is not possible, it must buffer WebSocket frames and not +send them onto ``websocket.receive`` until a reply is received, and if the +connection is rejected, return WebSocket close code ``4403``. + +Receiving a WebSocket *Send/Close* message while waiting for a +*Connection Reply* must make the server accept the connection and then send +the message immediately. Channel: ``websocket.connect`` @@ -768,6 +782,22 @@ Keys: * ``order``: The integer value ``0``. +Connection Reply +'''''''''''''''' + +Sent back on the reply channel from an application when a ``connect`` message +is received to say if the connection should be accepted or dropped. + +Behaviour on WebSocket rejection is defined in the Connection section above. + +Channel: ``websocket.send!`` + +Keys: + +* ``accept``: If the connection should be accepted (``True``) or rejected and + dropped (``False``). + + Receive ''''''' @@ -825,6 +855,9 @@ Send/Close Sends a data frame to the client and/or closes the connection from the server end. If ``ChannelFull`` is raised, wait and try again. +If sent while the connection is waiting for acceptance or rejection, +will accept the connection before the frame is sent. + Channel: ``websocket.send!`` Keys: diff --git a/docs/getting-started.rst b/docs/getting-started.rst index 104c7f5..37663a0 100644 --- a/docs/getting-started.rst +++ b/docs/getting-started.rst @@ -105,7 +105,7 @@ for ``http.request`` - and make this WebSocket consumer instead:: def ws_message(message): # ASGI WebSocket packet-received and send-packet message types - # both have a "text" key for their textual data. + # both have a "text" key for their textual data. message.reply_channel.send({ "text": message.content['text'], }) @@ -165,6 +165,7 @@ disconnect, like this:: # Connected to websocket.connect def ws_add(message): + message.reply_channel.send({"accept": True}) Group("chat").add(message.reply_channel) # Connected to websocket.disconnect @@ -203,6 +204,7 @@ get the message. Here's all the code:: # Connected to websocket.connect def ws_add(message): + message.reply_channel.send({"accept": True}) Group("chat").add(message.reply_channel) # Connected to websocket.receive @@ -363,6 +365,8 @@ name in the path of your WebSocket request (we'll ignore auth for now - that's n # Connected to websocket.connect @channel_session def ws_connect(message): + # Accept connection + message.reply_channel.send({"accept": True}) # Work out room name from path (ignore slashes) room = message.content['path'].strip("/") # Save room in session and add us to the group @@ -462,6 +466,8 @@ chat to people with the same first letter of their username:: # Connected to websocket.connect @channel_session_user_from_http def ws_add(message): + # Accept connection + message.reply_channel.send({"accept": True}) # Add them to the right group Group("chat-%s" % message.user.username[0]).add(message.reply_channel)