Add class-based consumers

This commit is contained in:
Andrew Godwin 2016-05-25 17:45:38 -07:00
parent cc9057e90c
commit bfacee6319
11 changed files with 503 additions and 22 deletions

View File

@ -6,6 +6,6 @@ DEFAULT_CHANNEL_LAYER = 'default'
try:
from .asgi import channel_layers # NOQA isort:skip
from .channel import Channel, Group # NOQA isort:skip
from .routing import route, include # NOQA isort:skip
from .routing import route, route_class, include # NOQA isort:skip
except ImportError: # No django installed, allow vars to be read
pass

View File

@ -0,0 +1 @@
from .base import BaseConsumer

40
channels/generic/base.py Normal file
View File

@ -0,0 +1,40 @@
from __future__ import unicode_literals
class BaseConsumer(object):
"""
Base class-based consumer class. Provides the mechanisms to be a direct
routing object and a few other things.
Class-based consumers should be used directly in routing with their
filters, like so::
routing = [
JsonWebsocketConsumer(path=r"^/liveblog/(?P<slug>[^/]+)/"),
]
"""
method_mapping = {}
def __init__(self, message, **kwargs):
"""
Constructor, called when a new message comes in (the consumer is
the uninstantiated class, so calling it creates it)
"""
self.message = message
self.dispatch(message, **kwargs)
@classmethod
def channel_names(cls):
"""
Returns a list of channels this consumer will respond to, in our case
derived from the method_mapping class attribute.
"""
return set(cls.method_mapping.keys())
def dispatch(self, message, **kwargs):
"""
Called with the message and all keyword arguments; uses method_mapping
to choose the right method to call.
"""
return getattr(self, self.method_mapping[message.channel.name])(message, **kwargs)

View File

@ -0,0 +1,137 @@
import json
from ..channel import Group
from ..sessions import enforce_ordering
from .base import BaseConsumer
class WebsocketConsumer(BaseConsumer):
"""
Base WebSocket consumer. Provides a general encapsulation for the
WebSocket handling model that other applications can build on.
"""
# You shouldn't need to override this
method_mapping = {
"websocket.connect": "raw_connect",
"websocket.receive": "raw_receive",
"websocket.disconnect": "raw_disconnect",
}
# Set one to True if you want the class to enforce ordering for you
slight_ordering = False
strict_ordering = False
def dispatch(self, message, **kwargs):
"""
Pulls out the path onto an instance variable, and optionally
adds the ordering decorator.
"""
self.path = message['path']
if self.strict_ordering:
return enforce_ordering(super(WebsocketConsumer, self).dispatch(message, **kwargs), slight=False)
elif self.slight_ordering:
return enforce_ordering(super(WebsocketConsumer, self).dispatch(message, **kwargs), slight=True)
else:
return super(WebsocketConsumer, self).dispatch(message, **kwargs)
def connection_groups(self, **kwargs):
"""
Group(s) to make people join when they connect and leave when they
disconnect. Make sure to return a list/tuple, not a string!
"""
return []
def raw_connect(self, message, **kwargs):
"""
Called when a WebSocket connection is opened. Base level so you don't
need to call super() all the time.
"""
for group in self.connection_groups(**kwargs):
Group(group, channel_layer=message.channel_layer).add(message.channel)
self.connect(message, **kwargs)
def connect(self, message, **kwargs):
"""
Called when a WebSocket connection is opened.
"""
pass
def raw_receive(self, message, **kwargs):
"""
Called when a WebSocket frame is received. Decodes it and passes it
to receive().
"""
if "text" in message:
self.receive(text=message['text'], **kwargs)
else:
self.receive(bytes=message['bytes'], **kwargs)
def receive(self, text=None, bytes=None, **kwargs):
"""
Called with a decoded WebSocket frame.
"""
pass
def send(self, text=None, bytes=None):
"""
Sends a reply back down the WebSocket
"""
if text is not None:
self.message.reply_channel.send({"text": text})
elif bytes is not None:
self.message.reply_channel.send({"bytes": bytes})
else:
raise ValueError("You must pass text or bytes")
def group_send(self, name, text=None, bytes=None):
if text is not None:
Group(name, channel_layer=self.message.channel_layer).send({"text": text})
elif bytes is not None:
Group(name, channel_layer=self.message.channel_layer).send({"bytes": bytes})
else:
raise ValueError("You must pass text or bytes")
def disconnect(self, message, **kwargs):
"""
Called when a WebSocket connection is closed. Base level so you don't
need to call super() all the time.
"""
for group in self.connection_groups(**kwargs):
Group(group, channel_layer=message.channel_layer).discard(message.channel)
self.disconnect(message, **kwargs)
def disconnect(self, message, **kwargs):
"""
Called when a WebSocket connection is opened.
"""
pass
class JsonWebsocketConsumer(WebsocketConsumer):
"""
Variant of WebsocketConsumer that automatically JSON-encodes and decodes
messages as they come in and go out. Expects everything to be text; will
error on binary data.
"""
def raw_receive(self, message, **kwargs):
if "text" in message:
self.receive(json.loads(message['text']), **kwargs)
else:
raise ValueError("No text section for incoming WebSocket frame!")
def receive(self, content, **kwargs):
"""
Called with decoded JSON content.
"""
pass
def send(self, content):
"""
Encode the given content as JSON and send it to the client.
"""
super(JsonWebsocketConsumer, self).send(text=json.dumps(content))
def group_send(self, name, content):
super(JsonWebsocketConsumer, self).group_send(name, json.dumps(content))

View File

@ -15,6 +15,9 @@ class Router(object):
listen to.
Generally this is attached to a backend instance as ".router"
Anything can be a routable object as long as it provides a match()
method that either returns (callable, kwargs) or None.
"""
def __init__(self, routing):
@ -89,19 +92,16 @@ class Route(object):
and optional message parameter matching.
"""
def __init__(self, channel, consumer, **kwargs):
# Get channel, make sure it's a unicode string
self.channel = channel
if isinstance(self.channel, six.binary_type):
self.channel = self.channel.decode("ascii")
def __init__(self, channels, consumer, **kwargs):
# Get channels, make sure it's a list of unicode strings
if isinstance(channels, six.string_types):
channels = [channels]
self.channels = [
channel.decode("ascii") if isinstance(channel, six.binary_type) else channel
for channel in channels
]
# Get consumer, optionally importing it
if isinstance(consumer, six.string_types):
module_name, variable_name = consumer.rsplit(".", 1)
try:
consumer = getattr(importlib.import_module(module_name), variable_name)
except (ImportError, AttributeError):
raise ImproperlyConfigured("Cannot import consumer %r" % consumer)
self.consumer = consumer
self.consumer = self._resolve_consumer(consumer)
# Compile filter regexes up front
self.filters = {
name: re.compile(Router.normalise_re_arg(value))
@ -118,13 +118,26 @@ class Route(object):
)
)
def _resolve_consumer(self, consumer):
"""
Turns the consumer from a string into an object if it's a string,
passes it through otherwise.
"""
if isinstance(consumer, six.string_types):
module_name, variable_name = consumer.rsplit(".", 1)
try:
consumer = getattr(importlib.import_module(module_name), variable_name)
except (ImportError, AttributeError):
raise ImproperlyConfigured("Cannot import consumer %r" % consumer)
return consumer
def match(self, message):
"""
Checks to see if we match the Message object. Returns
(consumer, kwargs dict) if it matches, None otherwise
"""
# Check for channel match first of all
if message.channel.name != self.channel:
if message.channel.name not in self.channels:
return None
# Check each message filter and build consumer kwargs as we go
call_args = {}
@ -143,11 +156,11 @@ class Route(object):
"""
Returns the channel names this route listens on
"""
return {self.channel, }
return set(self.channels)
def __str__(self):
return "%s %s -> %s" % (
self.channel,
"/".join(self.channels),
"" if not self.filters else "(%s)" % (
", ".join("%s=%s" % (n, v.pattern) for n, v in self.filters.items())
),
@ -155,6 +168,22 @@ class Route(object):
)
class RouteClass(Route):
"""
Like Route, but targets a class-based consumer rather than a functional
one, meaning it looks for a (class) method called "channels()" on the
object rather than having a single channel passed in.
"""
def __init__(self, consumer, **kwargs):
# Check the consumer provides a method_channels
consumer = self._resolve_consumer(consumer)
if not hasattr(consumer, "channel_names") or not callable(consumer.channel_names):
raise ValueError("The consumer passed to RouteClass has no valid channel_names method")
# Call super with list of channels
super(RouteClass, self).__init__(consumer.channel_names(), consumer, **kwargs)
class Include(object):
"""
Represents an inclusion of another routing list in another file.
@ -212,4 +241,5 @@ class Include(object):
# Lowercase standard to match urls.py
route = Route
route_class = RouteClass
include = Include

View File

@ -1,9 +1,10 @@
from __future__ import unicode_literals
from django.test import SimpleTestCase
from channels.routing import Router, route, include
from channels.routing import Router, route, route_class, include
from channels.message import Message
from channels.utils import name_that_thing
from channels.generic import BaseConsumer
# Fake consumers and routing sets that can be imported by string
@ -19,6 +20,16 @@ def consumer_3():
pass
class TestClassConsumer(BaseConsumer):
method_mapping = {
"test.channel": "some_method",
}
def some_method(self, message, **kwargs):
pass
chatroom_routing = [
route("websocket.connect", consumer_2, path=r"^/chat/(?P<room>[^/]+)/$"),
route("websocket.connect", consumer_3, path=r"^/mentions/$"),
@ -29,6 +40,10 @@ chatroom_routing_nolinestart = [
route("websocket.connect", consumer_3, path=r"/mentions/$"),
]
class_routing = [
route_class(TestClassConsumer, path=r"^/foobar/$"),
]
class RoutingTests(SimpleTestCase):
"""
@ -175,6 +190,32 @@ class RoutingTests(SimpleTestCase):
kwargs={},
)
def test_route_class(self):
"""
Tests route_class with/without prefix
"""
router = Router([
include("channels.tests.test_routing.class_routing"),
])
self.assertRoute(
router,
channel="websocket.connect",
content={"path": "/foobar/"},
consumer=None,
)
self.assertRoute(
router,
channel="test.channel",
content={"path": "/foobar/"},
consumer=TestClassConsumer,
)
self.assertRoute(
router,
channel="test.channel",
content={"path": "/"},
consumer=None,
)
def test_include_prefix(self):
"""
Tests inclusion with a prefix
@ -291,15 +332,16 @@ class RoutingTests(SimpleTestCase):
route("http.request", consumer_1, path=r"^/chat/$"),
route("http.disconnect", consumer_2),
route("http.request", consumer_3),
route_class(TestClassConsumer),
])
# Initial check
self.assertEqual(
router.channels,
{"http.request", "http.disconnect"},
{"http.request", "http.disconnect", "test.channel"},
)
# Dynamically add route, recheck
router.add_route(route("websocket.receive", consumer_1))
self.assertEqual(
router.channels,
{"http.request", "http.disconnect", "websocket.receive"},
{"http.request", "http.disconnect", "websocket.receive", "test.channel"},
)

View File

@ -1,12 +1,14 @@
import types
def name_that_thing(thing):
"""
Returns either the function/class path or just the object's repr
"""
# Instance method
if hasattr(thing, "im_class"):
# Mocks will recurse im_class forever
if hasattr(thing, "mock_calls"):
return "<mock>"
return name_that_thing(thing.im_class) + "." + thing.im_func.func_name
# Other named thing
if hasattr(thing, "__name__"):

View File

@ -72,7 +72,7 @@ class Worker(object):
if self.signal_handlers:
self.install_signal_handler()
channels = self.apply_channel_filters(self.channel_layer.router.channels)
logger.info("Listening on channels %s", ", ".join(channels))
logger.info("Listening on channels %s", ", ".join(sorted(channels)))
while not self.termed:
self.in_job = False
channel, content = self.channel_layer.receive_many(channels, block=True)
@ -82,7 +82,7 @@ class Worker(object):
time.sleep(0.01)
continue
# Create message wrapper
logger.debug("Worker got message on %s: repl %s", channel, content.get("reply_channel", "none"))
logger.debug("Got message on %s (reply %s)", channel, content.get("reply_channel", "none"))
message = Message(
content=content,
channel_name=channel,
@ -103,6 +103,7 @@ class Worker(object):
if self.callback:
self.callback(channel, message)
try:
logger.debug("Dispatching message on %s to %s", channel, name_that_thing(consumer))
consumer(message, **kwargs)
except ConsumeLater:
# They want to not handle it yet. Re-inject it with a number-of-tries marker.

150
docs/generics.rst Normal file
View File

@ -0,0 +1,150 @@
Generic Consumers
=================
Much like Django's class-based views, Channels has class-based consumers.
They provide a way for you to arrange code so it's highly modifiable and
inheritable, at the slight cost of it being harder to figure out the execution
path.
We recommend you use them if you find them valuable; normal function-based
consumers are also entirely valid, however, and may result in more readable
code for simpler tasks.
There is one base class-based consumer class, ``BaseConsumer``, that provides
the pattern for method dispatch and is the thing you can build entirely
custom consumers on top of, and then protocol-specific subclasses that provide
extra utility - for example, the ``WebsocketConsumer`` provides automatic
group management for the connection.
When you use class-based consumers in :doc:`routing <routing>`, you need
to use ``route_class`` rather than ``route``; ``route_class`` knows how to
talk to the class-based consumer and extract the list of channels it needs
to listen on from it directly, rather than making you pass it in explicitly.
Class-based consumers are instantiated once for each message they consume,
so it's safe to store things on ``self`` (in fact, ``self.message`` is the
current message by default).
Base
----
The ``BaseConsumer`` class is the foundation of class-based consumers, and what
you can inherit from if you wish to build your own entirely from scratch.
You use it like this::
from channels.generic import BaseConsumer
class MyConsumer(BaseConsumer):
method_mapping = {
"channel.name.here": "method_name",
}
def method_name(self, message, **kwargs):
pass
All you need to define is the ``method_mapping`` dictionary, which maps
channel names to method names. The base code will take care of the dispatching
for you, and set ``self.message`` to the current message as well.
If you want to perfom more complicated routing, you'll need to override the
``dispatch()`` and ``channel_names()`` methods in order to do the right thing;
remember, though, your channel names cannot change during runtime and must
always be the same for as long as your process runs.
WebSockets
----------
There are two WebSockets generic consumers; one that provides group management,
simpler send/receive methods, and basic method routing, and a subclass which
additionally automatically serializes all messages sent and receives using JSON.
The basic WebSocket generic consumer is used like this::
from channels.generic.websockets import WebsocketConsumer
class MyConsumer(WebsocketConsumer):
# Set to True if you want them, else leave out
strict_ordering = False
slight_ordering = False
def connection_groups(self, **kwargs):
"""
Called to return the list of groups to automatically add/remove
this connection to/from.
"""
return ["test"]
def connect(self, message, **kwargs):
"""
Perform things on connection start
"""
pass
def receive(self, text=None, bytes=None, **kwargs):
"""
Called when a message is received with either text or bytes
filled out.
"""
# Simple echo
self.send(text=text, bytes=bytes)
def disconnect(self, message, **kwargs):
"""
Perform things on connection close
"""
pass
You can call ``self.send`` inside the class to send things to the connection's
``reply_channel`` automatically. Any group names returned from ``connection_groups``
are used to add the socket to when it connects and to remove it from when it
disconnects; you get keyword arguments too if your URL path, say, affects
which group to talk to.
The JSON-enabled consumer looks slightly different::
from channels.generic.websockets import JsonWebsocketConsumer
class MyConsumer(JsonWebsocketConsumer):
# Set to True if you want them, else leave out
strict_ordering = False
slight_ordering = False
def connection_groups(self, **kwargs):
"""
Called to return the list of groups to automatically add/remove
this connection to/from.
"""
return ["test"]
def connect(self, message, **kwargs):
"""
Perform things on connection start
"""
pass
def receive(self, content, **kwargs):
"""
Called when a message is received with decoded JSON content
"""
# Simple echo
self.send(content)
def disconnect(self, message, **kwargs):
"""
Perform things on connection close
"""
pass
For this subclass, ``receive`` only gets a ``content`` parameter that is the
already-decoded JSON as Python datastructures; similarly, ``send`` now only
takes a single argument, which it JSON-encodes before sending down to the
client.
Note that this subclass still can't intercept ``Group.send()`` calls to make
them into JSON automatically, but it does provide ``self.group_send(name, content)``
that will do this for you if you call it explicitly.

View File

@ -29,6 +29,8 @@ Contents:
installation
getting-started
deploying
generics
routing
backends
testing
cross-compat

76
docs/routing.rst Normal file
View File

@ -0,0 +1,76 @@
Routing
=======
Routing in Channels is done using a system similar to that in core Django;
a list of possible routes is provided, and Channels goes through all routes
until a match is found, and then runs the resulting consumer.
The difference comes, however, in the fact that Channels has to route based
on more than just URL; channel name is the main thing routed on, and URL
path is one of many other optional things you can route on, depending on
the protocol (for example, imagine email consumers - they would route on
domain or recipient address instead).
The routing Channels takes is just a list of *routing objects* - the three
built in ones are ``route``, ``route_class`` and ``include``, but any object
that implements the routing interface will work:
* A method called ``match``, taking a single ``message`` as an argument and
returning ``None`` for no match or a tuple of ``(consumer, kwargs)`` if matched.
* A method called ``channel_names``, which returns a set of channel names that
will match, which is fed to the channel layer to listen on them.
The three default routing objects are:
* ``route``: Takes a channel name, a consumer function, and optional filter
keyword arguments.
* ``route_class``: Takes a class-based consumer, and optional filter
keyword arguments. Channel names are taken from the consumer's
``channel_names()`` method.
* ``include``: Takes either a list or string import path to a routing list,
and optional filter keyword arguments.
Filters
-------
Filtering is how you limit matches based on, for example, URLs; you use regular
expressions, like so::
route("websocket.connect", consumers.ws_connect, path=r"^/chat/$")
.. note::
Unlike Django's URL routing, which strips the first slash of a URL for
neatness, Channels includes the first slash, as the routing system is
generic and not designed just for URLs.
You can have multiple filters::
route("email.receive", comment_response, to_address=r".*@example.com$", subject="^reply")
Multiple filters are always combined with logical AND; that is, you need to
match every filter to have the consumer called.
Filters can capture keyword arguments to be passed to your function::
route("websocket.connect", connect_blog, path=r'^/liveblog/(?P<slug>[^/]+)/stream/$')
You can also specify filters on an ``include``::
include("blog_includes", path=r'^/liveblog')
When you specify filters on ``include``, the matched portion of the attribute
is removed for matches inside the include; for example, this arrangement
matches URLs like ``/liveblog/stream/``, because the outside ``include``
strips off the ``/liveblog`` part it matches before passing it inside::
inner_routes = [
route("websocket.connect", connect_blog, path=r'^/stream/'),
]
routing = [
include(inner_routes, path=r'^/liveblog')
]