From bfacee631903c6ca04b85f1e5228843f0b86fe65 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 25 May 2016 17:45:38 -0700 Subject: [PATCH] Add class-based consumers --- channels/__init__.py | 2 +- channels/generic/__init__.py | 1 + channels/generic/base.py | 40 +++++++++ channels/generic/websockets.py | 137 ++++++++++++++++++++++++++++++ channels/routing.py | 60 +++++++++---- channels/tests/test_routing.py | 48 ++++++++++- channels/utils.py | 4 +- channels/worker.py | 5 +- docs/generics.rst | 150 +++++++++++++++++++++++++++++++++ docs/index.rst | 2 + docs/routing.rst | 76 +++++++++++++++++ 11 files changed, 503 insertions(+), 22 deletions(-) create mode 100644 channels/generic/__init__.py create mode 100644 channels/generic/base.py create mode 100644 channels/generic/websockets.py create mode 100644 docs/generics.rst create mode 100644 docs/routing.rst diff --git a/channels/__init__.py b/channels/__init__.py index ae21a00..98ce119 100644 --- a/channels/__init__.py +++ b/channels/__init__.py @@ -6,6 +6,6 @@ DEFAULT_CHANNEL_LAYER = 'default' try: from .asgi import channel_layers # NOQA isort:skip from .channel import Channel, Group # NOQA isort:skip - from .routing import route, include # NOQA isort:skip + from .routing import route, route_class, include # NOQA isort:skip except ImportError: # No django installed, allow vars to be read pass diff --git a/channels/generic/__init__.py b/channels/generic/__init__.py new file mode 100644 index 0000000..3d08a87 --- /dev/null +++ b/channels/generic/__init__.py @@ -0,0 +1 @@ +from .base import BaseConsumer diff --git a/channels/generic/base.py b/channels/generic/base.py new file mode 100644 index 0000000..89c387c --- /dev/null +++ b/channels/generic/base.py @@ -0,0 +1,40 @@ +from __future__ import unicode_literals + + +class BaseConsumer(object): + """ + Base class-based consumer class. Provides the mechanisms to be a direct + routing object and a few other things. + + Class-based consumers should be used directly in routing with their + filters, like so:: + + routing = [ + JsonWebsocketConsumer(path=r"^/liveblog/(?P[^/]+)/"), + ] + """ + + method_mapping = {} + + def __init__(self, message, **kwargs): + """ + Constructor, called when a new message comes in (the consumer is + the uninstantiated class, so calling it creates it) + """ + self.message = message + self.dispatch(message, **kwargs) + + @classmethod + def channel_names(cls): + """ + Returns a list of channels this consumer will respond to, in our case + derived from the method_mapping class attribute. + """ + return set(cls.method_mapping.keys()) + + def dispatch(self, message, **kwargs): + """ + Called with the message and all keyword arguments; uses method_mapping + to choose the right method to call. + """ + return getattr(self, self.method_mapping[message.channel.name])(message, **kwargs) diff --git a/channels/generic/websockets.py b/channels/generic/websockets.py new file mode 100644 index 0000000..9faf75d --- /dev/null +++ b/channels/generic/websockets.py @@ -0,0 +1,137 @@ +import json + +from ..channel import Group +from ..sessions import enforce_ordering +from .base import BaseConsumer + + +class WebsocketConsumer(BaseConsumer): + """ + Base WebSocket consumer. Provides a general encapsulation for the + WebSocket handling model that other applications can build on. + """ + + # You shouldn't need to override this + method_mapping = { + "websocket.connect": "raw_connect", + "websocket.receive": "raw_receive", + "websocket.disconnect": "raw_disconnect", + } + + # Set one to True if you want the class to enforce ordering for you + slight_ordering = False + strict_ordering = False + + def dispatch(self, message, **kwargs): + """ + Pulls out the path onto an instance variable, and optionally + adds the ordering decorator. + """ + self.path = message['path'] + if self.strict_ordering: + return enforce_ordering(super(WebsocketConsumer, self).dispatch(message, **kwargs), slight=False) + elif self.slight_ordering: + return enforce_ordering(super(WebsocketConsumer, self).dispatch(message, **kwargs), slight=True) + else: + return super(WebsocketConsumer, self).dispatch(message, **kwargs) + + def connection_groups(self, **kwargs): + """ + Group(s) to make people join when they connect and leave when they + disconnect. Make sure to return a list/tuple, not a string! + """ + return [] + + def raw_connect(self, message, **kwargs): + """ + Called when a WebSocket connection is opened. Base level so you don't + need to call super() all the time. + """ + for group in self.connection_groups(**kwargs): + Group(group, channel_layer=message.channel_layer).add(message.channel) + self.connect(message, **kwargs) + + def connect(self, message, **kwargs): + """ + Called when a WebSocket connection is opened. + """ + pass + + def raw_receive(self, message, **kwargs): + """ + Called when a WebSocket frame is received. Decodes it and passes it + to receive(). + """ + if "text" in message: + self.receive(text=message['text'], **kwargs) + else: + self.receive(bytes=message['bytes'], **kwargs) + + def receive(self, text=None, bytes=None, **kwargs): + """ + Called with a decoded WebSocket frame. + """ + pass + + def send(self, text=None, bytes=None): + """ + Sends a reply back down the WebSocket + """ + if text is not None: + self.message.reply_channel.send({"text": text}) + elif bytes is not None: + self.message.reply_channel.send({"bytes": bytes}) + else: + raise ValueError("You must pass text or bytes") + + def group_send(self, name, text=None, bytes=None): + if text is not None: + Group(name, channel_layer=self.message.channel_layer).send({"text": text}) + elif bytes is not None: + Group(name, channel_layer=self.message.channel_layer).send({"bytes": bytes}) + else: + raise ValueError("You must pass text or bytes") + + def disconnect(self, message, **kwargs): + """ + Called when a WebSocket connection is closed. Base level so you don't + need to call super() all the time. + """ + for group in self.connection_groups(**kwargs): + Group(group, channel_layer=message.channel_layer).discard(message.channel) + self.disconnect(message, **kwargs) + + def disconnect(self, message, **kwargs): + """ + Called when a WebSocket connection is opened. + """ + pass + + +class JsonWebsocketConsumer(WebsocketConsumer): + """ + Variant of WebsocketConsumer that automatically JSON-encodes and decodes + messages as they come in and go out. Expects everything to be text; will + error on binary data. + """ + + def raw_receive(self, message, **kwargs): + if "text" in message: + self.receive(json.loads(message['text']), **kwargs) + else: + raise ValueError("No text section for incoming WebSocket frame!") + + def receive(self, content, **kwargs): + """ + Called with decoded JSON content. + """ + pass + + def send(self, content): + """ + Encode the given content as JSON and send it to the client. + """ + super(JsonWebsocketConsumer, self).send(text=json.dumps(content)) + + def group_send(self, name, content): + super(JsonWebsocketConsumer, self).group_send(name, json.dumps(content)) diff --git a/channels/routing.py b/channels/routing.py index afa2946..83108ba 100644 --- a/channels/routing.py +++ b/channels/routing.py @@ -15,6 +15,9 @@ class Router(object): listen to. Generally this is attached to a backend instance as ".router" + + Anything can be a routable object as long as it provides a match() + method that either returns (callable, kwargs) or None. """ def __init__(self, routing): @@ -89,19 +92,16 @@ class Route(object): and optional message parameter matching. """ - def __init__(self, channel, consumer, **kwargs): - # Get channel, make sure it's a unicode string - self.channel = channel - if isinstance(self.channel, six.binary_type): - self.channel = self.channel.decode("ascii") + def __init__(self, channels, consumer, **kwargs): + # Get channels, make sure it's a list of unicode strings + if isinstance(channels, six.string_types): + channels = [channels] + self.channels = [ + channel.decode("ascii") if isinstance(channel, six.binary_type) else channel + for channel in channels + ] # Get consumer, optionally importing it - if isinstance(consumer, six.string_types): - module_name, variable_name = consumer.rsplit(".", 1) - try: - consumer = getattr(importlib.import_module(module_name), variable_name) - except (ImportError, AttributeError): - raise ImproperlyConfigured("Cannot import consumer %r" % consumer) - self.consumer = consumer + self.consumer = self._resolve_consumer(consumer) # Compile filter regexes up front self.filters = { name: re.compile(Router.normalise_re_arg(value)) @@ -118,13 +118,26 @@ class Route(object): ) ) + def _resolve_consumer(self, consumer): + """ + Turns the consumer from a string into an object if it's a string, + passes it through otherwise. + """ + if isinstance(consumer, six.string_types): + module_name, variable_name = consumer.rsplit(".", 1) + try: + consumer = getattr(importlib.import_module(module_name), variable_name) + except (ImportError, AttributeError): + raise ImproperlyConfigured("Cannot import consumer %r" % consumer) + return consumer + def match(self, message): """ Checks to see if we match the Message object. Returns (consumer, kwargs dict) if it matches, None otherwise """ # Check for channel match first of all - if message.channel.name != self.channel: + if message.channel.name not in self.channels: return None # Check each message filter and build consumer kwargs as we go call_args = {} @@ -143,11 +156,11 @@ class Route(object): """ Returns the channel names this route listens on """ - return {self.channel, } + return set(self.channels) def __str__(self): return "%s %s -> %s" % ( - self.channel, + "/".join(self.channels), "" if not self.filters else "(%s)" % ( ", ".join("%s=%s" % (n, v.pattern) for n, v in self.filters.items()) ), @@ -155,6 +168,22 @@ class Route(object): ) +class RouteClass(Route): + """ + Like Route, but targets a class-based consumer rather than a functional + one, meaning it looks for a (class) method called "channels()" on the + object rather than having a single channel passed in. + """ + + def __init__(self, consumer, **kwargs): + # Check the consumer provides a method_channels + consumer = self._resolve_consumer(consumer) + if not hasattr(consumer, "channel_names") or not callable(consumer.channel_names): + raise ValueError("The consumer passed to RouteClass has no valid channel_names method") + # Call super with list of channels + super(RouteClass, self).__init__(consumer.channel_names(), consumer, **kwargs) + + class Include(object): """ Represents an inclusion of another routing list in another file. @@ -212,4 +241,5 @@ class Include(object): # Lowercase standard to match urls.py route = Route +route_class = RouteClass include = Include diff --git a/channels/tests/test_routing.py b/channels/tests/test_routing.py index 891113f..1a00e43 100644 --- a/channels/tests/test_routing.py +++ b/channels/tests/test_routing.py @@ -1,9 +1,10 @@ from __future__ import unicode_literals from django.test import SimpleTestCase -from channels.routing import Router, route, include +from channels.routing import Router, route, route_class, include from channels.message import Message from channels.utils import name_that_thing +from channels.generic import BaseConsumer # Fake consumers and routing sets that can be imported by string @@ -19,6 +20,16 @@ def consumer_3(): pass +class TestClassConsumer(BaseConsumer): + + method_mapping = { + "test.channel": "some_method", + } + + def some_method(self, message, **kwargs): + pass + + chatroom_routing = [ route("websocket.connect", consumer_2, path=r"^/chat/(?P[^/]+)/$"), route("websocket.connect", consumer_3, path=r"^/mentions/$"), @@ -29,6 +40,10 @@ chatroom_routing_nolinestart = [ route("websocket.connect", consumer_3, path=r"/mentions/$"), ] +class_routing = [ + route_class(TestClassConsumer, path=r"^/foobar/$"), +] + class RoutingTests(SimpleTestCase): """ @@ -175,6 +190,32 @@ class RoutingTests(SimpleTestCase): kwargs={}, ) + def test_route_class(self): + """ + Tests route_class with/without prefix + """ + router = Router([ + include("channels.tests.test_routing.class_routing"), + ]) + self.assertRoute( + router, + channel="websocket.connect", + content={"path": "/foobar/"}, + consumer=None, + ) + self.assertRoute( + router, + channel="test.channel", + content={"path": "/foobar/"}, + consumer=TestClassConsumer, + ) + self.assertRoute( + router, + channel="test.channel", + content={"path": "/"}, + consumer=None, + ) + def test_include_prefix(self): """ Tests inclusion with a prefix @@ -291,15 +332,16 @@ class RoutingTests(SimpleTestCase): route("http.request", consumer_1, path=r"^/chat/$"), route("http.disconnect", consumer_2), route("http.request", consumer_3), + route_class(TestClassConsumer), ]) # Initial check self.assertEqual( router.channels, - {"http.request", "http.disconnect"}, + {"http.request", "http.disconnect", "test.channel"}, ) # Dynamically add route, recheck router.add_route(route("websocket.receive", consumer_1)) self.assertEqual( router.channels, - {"http.request", "http.disconnect", "websocket.receive"}, + {"http.request", "http.disconnect", "websocket.receive", "test.channel"}, ) diff --git a/channels/utils.py b/channels/utils.py index e8f060e..dc96201 100644 --- a/channels/utils.py +++ b/channels/utils.py @@ -1,12 +1,14 @@ import types - def name_that_thing(thing): """ Returns either the function/class path or just the object's repr """ # Instance method if hasattr(thing, "im_class"): + # Mocks will recurse im_class forever + if hasattr(thing, "mock_calls"): + return "" return name_that_thing(thing.im_class) + "." + thing.im_func.func_name # Other named thing if hasattr(thing, "__name__"): diff --git a/channels/worker.py b/channels/worker.py index 935ed80..caafd95 100644 --- a/channels/worker.py +++ b/channels/worker.py @@ -72,7 +72,7 @@ class Worker(object): if self.signal_handlers: self.install_signal_handler() channels = self.apply_channel_filters(self.channel_layer.router.channels) - logger.info("Listening on channels %s", ", ".join(channels)) + logger.info("Listening on channels %s", ", ".join(sorted(channels))) while not self.termed: self.in_job = False channel, content = self.channel_layer.receive_many(channels, block=True) @@ -82,7 +82,7 @@ class Worker(object): time.sleep(0.01) continue # Create message wrapper - logger.debug("Worker got message on %s: repl %s", channel, content.get("reply_channel", "none")) + logger.debug("Got message on %s (reply %s)", channel, content.get("reply_channel", "none")) message = Message( content=content, channel_name=channel, @@ -103,6 +103,7 @@ class Worker(object): if self.callback: self.callback(channel, message) try: + logger.debug("Dispatching message on %s to %s", channel, name_that_thing(consumer)) consumer(message, **kwargs) except ConsumeLater: # They want to not handle it yet. Re-inject it with a number-of-tries marker. diff --git a/docs/generics.rst b/docs/generics.rst new file mode 100644 index 0000000..44f9629 --- /dev/null +++ b/docs/generics.rst @@ -0,0 +1,150 @@ +Generic Consumers +================= + +Much like Django's class-based views, Channels has class-based consumers. +They provide a way for you to arrange code so it's highly modifiable and +inheritable, at the slight cost of it being harder to figure out the execution +path. + +We recommend you use them if you find them valuable; normal function-based +consumers are also entirely valid, however, and may result in more readable +code for simpler tasks. + +There is one base class-based consumer class, ``BaseConsumer``, that provides +the pattern for method dispatch and is the thing you can build entirely +custom consumers on top of, and then protocol-specific subclasses that provide +extra utility - for example, the ``WebsocketConsumer`` provides automatic +group management for the connection. + +When you use class-based consumers in :doc:`routing `, you need +to use ``route_class`` rather than ``route``; ``route_class`` knows how to +talk to the class-based consumer and extract the list of channels it needs +to listen on from it directly, rather than making you pass it in explicitly. + +Class-based consumers are instantiated once for each message they consume, +so it's safe to store things on ``self`` (in fact, ``self.message`` is the +current message by default). + +Base +---- + +The ``BaseConsumer`` class is the foundation of class-based consumers, and what +you can inherit from if you wish to build your own entirely from scratch. + +You use it like this:: + + from channels.generic import BaseConsumer + + class MyConsumer(BaseConsumer): + + method_mapping = { + "channel.name.here": "method_name", + } + + def method_name(self, message, **kwargs): + pass + +All you need to define is the ``method_mapping`` dictionary, which maps +channel names to method names. The base code will take care of the dispatching +for you, and set ``self.message`` to the current message as well. + +If you want to perfom more complicated routing, you'll need to override the +``dispatch()`` and ``channel_names()`` methods in order to do the right thing; +remember, though, your channel names cannot change during runtime and must +always be the same for as long as your process runs. + + +WebSockets +---------- + +There are two WebSockets generic consumers; one that provides group management, +simpler send/receive methods, and basic method routing, and a subclass which +additionally automatically serializes all messages sent and receives using JSON. + +The basic WebSocket generic consumer is used like this:: + + from channels.generic.websockets import WebsocketConsumer + + class MyConsumer(WebsocketConsumer): + + # Set to True if you want them, else leave out + strict_ordering = False + slight_ordering = False + + def connection_groups(self, **kwargs): + """ + Called to return the list of groups to automatically add/remove + this connection to/from. + """ + return ["test"] + + def connect(self, message, **kwargs): + """ + Perform things on connection start + """ + pass + + def receive(self, text=None, bytes=None, **kwargs): + """ + Called when a message is received with either text or bytes + filled out. + """ + # Simple echo + self.send(text=text, bytes=bytes) + + def disconnect(self, message, **kwargs): + """ + Perform things on connection close + """ + pass + +You can call ``self.send`` inside the class to send things to the connection's +``reply_channel`` automatically. Any group names returned from ``connection_groups`` +are used to add the socket to when it connects and to remove it from when it +disconnects; you get keyword arguments too if your URL path, say, affects +which group to talk to. + +The JSON-enabled consumer looks slightly different:: + + from channels.generic.websockets import JsonWebsocketConsumer + + class MyConsumer(JsonWebsocketConsumer): + + # Set to True if you want them, else leave out + strict_ordering = False + slight_ordering = False + + def connection_groups(self, **kwargs): + """ + Called to return the list of groups to automatically add/remove + this connection to/from. + """ + return ["test"] + + def connect(self, message, **kwargs): + """ + Perform things on connection start + """ + pass + + def receive(self, content, **kwargs): + """ + Called when a message is received with decoded JSON content + """ + # Simple echo + self.send(content) + + def disconnect(self, message, **kwargs): + """ + Perform things on connection close + """ + pass + +For this subclass, ``receive`` only gets a ``content`` parameter that is the +already-decoded JSON as Python datastructures; similarly, ``send`` now only +takes a single argument, which it JSON-encodes before sending down to the +client. + +Note that this subclass still can't intercept ``Group.send()`` calls to make +them into JSON automatically, but it does provide ``self.group_send(name, content)`` +that will do this for you if you call it explicitly. diff --git a/docs/index.rst b/docs/index.rst index 92e06e4..97b3f0a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,6 +29,8 @@ Contents: installation getting-started deploying + generics + routing backends testing cross-compat diff --git a/docs/routing.rst b/docs/routing.rst new file mode 100644 index 0000000..d3fd116 --- /dev/null +++ b/docs/routing.rst @@ -0,0 +1,76 @@ +Routing +======= + +Routing in Channels is done using a system similar to that in core Django; +a list of possible routes is provided, and Channels goes through all routes +until a match is found, and then runs the resulting consumer. + +The difference comes, however, in the fact that Channels has to route based +on more than just URL; channel name is the main thing routed on, and URL +path is one of many other optional things you can route on, depending on +the protocol (for example, imagine email consumers - they would route on +domain or recipient address instead). + +The routing Channels takes is just a list of *routing objects* - the three +built in ones are ``route``, ``route_class`` and ``include``, but any object +that implements the routing interface will work: + +* A method called ``match``, taking a single ``message`` as an argument and + returning ``None`` for no match or a tuple of ``(consumer, kwargs)`` if matched. + +* A method called ``channel_names``, which returns a set of channel names that + will match, which is fed to the channel layer to listen on them. + +The three default routing objects are: + +* ``route``: Takes a channel name, a consumer function, and optional filter + keyword arguments. + +* ``route_class``: Takes a class-based consumer, and optional filter + keyword arguments. Channel names are taken from the consumer's + ``channel_names()`` method. + +* ``include``: Takes either a list or string import path to a routing list, + and optional filter keyword arguments. + + +Filters +------- + +Filtering is how you limit matches based on, for example, URLs; you use regular +expressions, like so:: + + route("websocket.connect", consumers.ws_connect, path=r"^/chat/$") + +.. note:: + Unlike Django's URL routing, which strips the first slash of a URL for + neatness, Channels includes the first slash, as the routing system is + generic and not designed just for URLs. + +You can have multiple filters:: + + route("email.receive", comment_response, to_address=r".*@example.com$", subject="^reply") + +Multiple filters are always combined with logical AND; that is, you need to +match every filter to have the consumer called. + +Filters can capture keyword arguments to be passed to your function:: + + route("websocket.connect", connect_blog, path=r'^/liveblog/(?P[^/]+)/stream/$') + +You can also specify filters on an ``include``:: + + include("blog_includes", path=r'^/liveblog') + +When you specify filters on ``include``, the matched portion of the attribute +is removed for matches inside the include; for example, this arrangement +matches URLs like ``/liveblog/stream/``, because the outside ``include`` +strips off the ``/liveblog`` part it matches before passing it inside:: + + inner_routes = [ + route("websocket.connect", connect_blog, path=r'^/stream/'), + ] + + routing = [ + include(inner_routes, path=r'^/liveblog') + ]