Change to a full pattern-based routing system.

This commit is contained in:
Andrew Godwin 2016-03-20 12:27:42 -03:00
parent a914cfdcb6
commit 841e19da79
9 changed files with 539 additions and 106 deletions

View File

@ -4,7 +4,7 @@ import django
from django.conf import settings
from django.utils.module_loading import import_string
from .consumer_registry import ConsumerRegistry
from .routing import Router
from .utils import name_that_thing
@ -67,7 +67,7 @@ class ChannelLayerWrapper(object):
self.channel_layer = channel_layer
self.alias = alias
self.routing = routing
self.registry = ConsumerRegistry(self.routing)
self.router = Router(self.routing)
def __getattr__(self, name):
return getattr(self.channel_layer, name)

View File

@ -1,76 +0,0 @@
from __future__ import unicode_literals
import importlib
from django.core.exceptions import ImproperlyConfigured
from django.utils import six
from .handler import ViewConsumer
from .utils import name_that_thing
class ConsumerRegistry(object):
"""
Manages the available consumers in the project and which channels they
listen to.
Generally this is attached to a backend instance as ".registry"
"""
def __init__(self, routing=None):
self.consumers = {}
# Initialise with any routing that was passed in
if routing:
# If the routing was a string, import it
if isinstance(routing, six.string_types):
module_name, variable_name = routing.rsplit(".", 1)
try:
routing = getattr(importlib.import_module(module_name), variable_name)
except (ImportError, AttributeError) as e:
raise ImproperlyConfigured("Cannot import channel routing %r: %s" % (routing, e))
# Load consumers into us
for channel, handler in routing.items():
self.add_consumer(handler, [channel])
def add_consumer(self, consumer, channels):
# Upconvert if you just pass in a string for channels
if isinstance(channels, six.string_types):
channels = [channels]
# Make sure all channels are byte strings
channels = [
channel.decode("ascii") if isinstance(channel, six.binary_type) else channel
for channel in channels
]
# Import any consumer referenced as string
if isinstance(consumer, six.string_types):
module_name, variable_name = consumer.rsplit(".", 1)
try:
consumer = getattr(importlib.import_module(module_name), variable_name)
except (ImportError, AttributeError):
raise ImproperlyConfigured("Cannot import consumer %r" % consumer)
# Register on each channel, checking it's unique
for channel in channels:
if channel in self.consumers:
raise ValueError("Cannot register consumer %s - channel %r already consumed by %s" % (
name_that_thing(consumer),
channel,
name_that_thing(self.consumers[channel]),
))
self.consumers[channel] = consumer
def all_channel_names(self):
return self.consumers.keys()
def consumer_for_channel(self, channel):
try:
return self.consumers[channel]
except KeyError:
return None
def check_default(self, http_consumer=None):
"""
Checks to see if default handlers need to be registered
for channels, and adds them if they need to be.
"""
if not self.consumer_for_channel("http.request"):
self.add_consumer(http_consumer or ViewConsumer(), ["http.request"])

View File

@ -35,7 +35,7 @@ class Command(RunserverCommand):
return RunserverCommand.inner_run(self, *args, **options)
# Check a handler is registered for http reqs; if not, add default one
self.channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER]
self.channel_layer.registry.check_default(
self.channel_layer.router.check_default(
http_consumer=self.get_consumer(),
)
# Run checks

View File

@ -27,7 +27,7 @@ class Command(BaseCommand):
"Change your settings to use a cross-process channel layer."
)
# Check a handler is registered for http reqs
self.channel_layer.registry.check_default()
self.channel_layer.router.check_default()
# Launch a worker
self.logger.info("Running worker against channel layer %s", self.channel_layer)
# Optionally provide an output callback

220
channels/routing.py Normal file
View File

@ -0,0 +1,220 @@
from __future__ import unicode_literals
import re
import importlib
from django.core.exceptions import ImproperlyConfigured
from django.utils import six
from .handler import ViewConsumer
from .utils import name_that_thing
class Router(object):
"""
Manages the available consumers in the project and which channels they
listen to.
Generally this is attached to a backend instance as ".router"
"""
def __init__(self, routing):
# Resolve routing into a list if it's a dict or string
routing = self.resolve_routing(routing)
# Expand those entries recursively into a flat list of Routes
self.routing = []
for entry in routing:
self.routing.extend(entry.expand_routes())
# Now go through that list and collect channel names into a set
self.channels = {
route.channel
for route in self.routing
}
def add_route(self, route):
"""
Adds a single raw Route to us at the end of the resolution list.
"""
self.routing.append(route)
self.channels.add(route.channel)
def match(self, message):
"""
Runs through our routing and tries to find a consumer that matches
the message/channel. Returns (consumer, extra_kwargs) if it does,
and None if it doesn't.
"""
# TODO: Maybe we can add some kind of caching in here if we can hash
# the message with only matchable keys faster than the search?
for route in self.routing:
match = route.match(message)
if match is not None:
return match
return None
def check_default(self, http_consumer=None):
"""
Adds default handlers for Django's default handling of channels.
"""
# We just add the default Django route to the bottom; if the user
# has defined another http.request handler, it'll get hit first and run.
self.add_route(Route("http.request", http_consumer or ViewConsumer()))
@classmethod
def resolve_routing(cls, routing):
"""
Takes a routing - if it's a string, it imports it, and if it's a
dict, converts it to a list of route()s. Used by this class and Include.
"""
# If the routing was a string, import it
if isinstance(routing, six.string_types):
module_name, variable_name = routing.rsplit(".", 1)
try:
routing = getattr(importlib.import_module(module_name), variable_name)
except (ImportError, AttributeError) as e:
raise ImproperlyConfigured("Cannot import channel routing %r: %s" % (routing, e))
# If the routing is a dict, convert it
if isinstance(routing, dict):
routing = [
Route(channel, consumer)
for channel, consumer in routing.items()
]
return routing
class Route(object):
"""
Represents a route to a single consumer, with a channel name
and optional message parameter matching.
"""
def __init__(self, channel, consumer, **kwargs):
# Get channel, make sure it's a unicode string
self.channel = channel
if isinstance(self.channel, six.binary_type):
self.channel = self.channel.decode("ascii")
# Get consumer, optionally importing it
if isinstance(consumer, six.string_types):
module_name, variable_name = consumer.rsplit(".", 1)
try:
consumer = getattr(importlib.import_module(module_name), variable_name)
except (ImportError, AttributeError):
raise ImproperlyConfigured("Cannot import consumer %r" % consumer)
self.consumer = consumer
# Compile filter regexes up front
self.filters = {
name: re.compile(value)
for name, value in kwargs.items()
}
# Check filters don't use positional groups
for name, regex in self.filters.items():
if regex.groups != len(regex.groupindex):
raise ValueError(
"Filter for %s on %s contains positional groups; "
"only named groups are allowed." % (
name,
self,
)
)
def match(self, message):
"""
Checks to see if we match the Message object. Returns
(consumer, kwargs dict) if it matches, None otherwise
"""
# Check for channel match first of all
if message.channel.name != self.channel:
return None
# Check each message filter and build consumer kwargs as we go
call_args = {}
for name, value in self.filters.items():
if name not in message:
return None
match = re.match(value, message[name])
# Any match failure means we pass
if match:
call_args.update(match.groupdict())
else:
return None
return self.consumer, call_args
def expand_routes(self):
"""
Expands this route into a list of just itself.
"""
return [self]
def add_prefixes(self, prefixes):
"""
Returns a new Route with the given prefixes added to our filters.
"""
new_filters = {}
# Copy over our filters adding any prefixes
for name, value in self.filters.items():
if name in prefixes:
if not value.pattern.startswith("^"):
raise ValueError("Cannot add prefix for %s on %s as inner value does not start with ^" % (
name,
self,
))
if "$" in prefixes[name]:
raise ValueError("Cannot add prefix for %s on %s as prefix contains $ (end of line match)" % (
name,
self,
))
new_filters[name] = re.compile(prefixes[name] + value.pattern.lstrip("^"))
else:
new_filters[name] = value
# Now add any prefixes that are by themselves so they're still enforced
for name, prefix in prefixes.items():
if name not in new_filters:
new_filters[name] = prefix
# Return new copy
return self.__class__(
self.channel,
self.consumer,
**new_filters
)
def __str__(self):
return "%s %s -> %s" % (
self.channel,
"" if not self.filters else "(%s)" % (
", ".join("%s=%s" % (n, v.pattern) for n, v in self.filters.items())
),
name_that_thing(self.consumer),
)
class Include(object):
"""
Represents an inclusion of another routing list in another file.
Will automatically modify message match filters to add prefixes,
if specified.
"""
def __init__(self, routing, **kwargs):
self.routing = Routing.resolve_routing(routing)
self.prefixes = kwargs
# Sanity check prefix regexes
for name, value in self.prefixes.items():
if not value.startswith("^"):
raise ValueError("Include prefix for %s must start with the ^ character." % name)
def expand_routes(self):
"""
Expands this Include into a list of routes, first recursively expanding
and then adding on prefixes to filters if specified.
"""
# First, expand our own subset of routes, to get a list of Route objects
routes = []
for entry in self.routing:
routes.extend(entry.expand_routes())
# Then, go through those and add any prefixes we have.
routes = [route.add_prefixes(self.prefixes) for route in routes]
return routes
# Lowercase standard to match urls.py
route = Route
include = Include

View File

@ -0,0 +1,234 @@
from __future__ import unicode_literals
from django.test import SimpleTestCase
from django.utils import six
from channels.routing import Router, route, include
from channels.message import Message
from channels.utils import name_that_thing
# Fake consumers and routing sets that can be imported by string
def consumer_1():
pass
def consumer_2():
pass
def consumer_3():
pass
chatroom_routing = [
route("websocket.connect", consumer_2, path=r"^/chat/(?P<room>[^/]+)/$"),
route("websocket.connect", consumer_3, path=r"^/mentions/$"),
]
chatroom_routing_noprefix = [
route("websocket.connect", consumer_2, path=r"/chat/(?P<room>[^/]+)/$"),
route("websocket.connect", consumer_3, path=r"/mentions/$"),
]
class RoutingTests(SimpleTestCase):
"""
Tests that the router's routing code works correctly.
"""
# Fake consumers we can test for with the == operator
def consumer_1(self):
pass
def consumer_2(self):
pass
def consumer_3(self):
pass
def assertRoute(self, router, channel, content, consumer, kwargs=None):
"""
Asserts that asking the `router` to route the `content` as a message
from `channel` means it returns consumer `consumer`, optionally
testing it also returns `kwargs` to be passed in
Use `consumer` = None to assert that no route is found.
"""
message = Message(content, channel, channel_layer="fake channel layer")
match = router.match(message)
if match is None:
if consumer is None:
return
else:
self.fail("No route found for %s on %s; expecting %s" % (
content,
channel,
name_that_thing(consumer),
))
else:
mconsumer, mkwargs = match
if consumer is None:
self.fail("Route found for %s on %s; expecting no route." % (
content,
channel,
))
self.assertEqual(consumer, mconsumer, "Route found for %s on %s; but wrong consumer (%s not %s)." % (
content,
channel,
name_that_thing(mconsumer),
name_that_thing(consumer),
))
if kwargs is not None:
self.assertEqual(kwargs, mkwargs, "Route found for %s on %s; but wrong kwargs (%s not %s)." % (
content,
channel,
mkwargs,
kwargs,
))
def test_assumption(self):
"""
Ensures the test consumers don't compare equal, as if this ever happens
this test file will pass and miss most bugs.
"""
self.assertNotEqual(consumer_1, consumer_2)
self.assertNotEqual(consumer_1, consumer_3)
def test_dict(self):
"""
Tests dict expansion
"""
router = Router({
"http.request": consumer_1,
"http.disconnect": consumer_2,
})
self.assertRoute(
router,
channel="http.request",
content={},
consumer=consumer_1,
kwargs={},
)
self.assertRoute(
router,
channel="http.request",
content={"path": "/chat/"},
consumer=consumer_1,
kwargs={},
)
self.assertRoute(
router,
channel="http.disconnect",
content={},
consumer=consumer_2,
kwargs={},
)
def test_filters(self):
"""
Tests that filters catch things correctly.
"""
router = Router([
route("http.request", consumer_1, path=r"^/chat/$"),
route("http.disconnect", consumer_2),
route("http.request", consumer_3),
])
# Filter hit
self.assertRoute(
router,
channel="http.request",
content={"path": "/chat/"},
consumer=consumer_1,
kwargs={},
)
# Fall-through
self.assertRoute(
router,
channel="http.request",
content={},
consumer=consumer_3,
kwargs={},
)
self.assertRoute(
router,
channel="http.request",
content={"path": "/liveblog/"},
consumer=consumer_3,
kwargs={},
)
def test_include(self):
"""
Tests inclusion without a prefix
"""
router = Router([
include("channels.tests.test_routing.chatroom_routing"),
])
self.assertRoute(
router,
channel="websocket.connect",
content={"path": "/boom/"},
consumer=None,
)
self.assertRoute(
router,
channel="websocket.connect",
content={"path": "/chat/django/"},
consumer=consumer_2,
kwargs={"room": "django"},
)
self.assertRoute(
router,
channel="websocket.connect",
content={"path": "/mentions/"},
consumer=consumer_3,
kwargs={},
)
def test_include_prefix(self):
"""
Tests inclusion with a prefix
"""
router = Router([
include("channels.tests.test_routing.chatroom_routing", path="^/ws/v(?P<version>[0-9]+)"),
])
self.assertRoute(
router,
channel="websocket.connect",
content={"path": "/boom/"},
consumer=None,
)
self.assertRoute(
router,
channel="websocket.connect",
content={"path": "/chat/django/"},
consumer=None,
)
self.assertRoute(
router,
channel="websocket.connect",
content={"path": "/ws/v2/chat/django/"},
consumer=consumer_2,
kwargs={"version": "2", "room": "django"},
)
self.assertRoute(
router,
channel="websocket.connect",
content={"path": "/ws/v1/mentions/"},
consumer=consumer_3,
kwargs={"version": "1"},
)
def test_positional_pattern(self):
"""
Tests that regexes with positional groups are rejected.
"""
with self.assertRaises(ValueError):
Consumerrouter([
route("http.request", consumer_1, path=r"^/chat/([^/]+)/$"),
])
def test_bad_include_prefix(self):
"""
Tests both failure cases of prefixes for includes - the include not
starting with ^, and the included filter not starting with ^.
"""
with self.assertRaises(ValueError):
Consumerrouter([
include("channels.tests.test_routing.chatroom_routing", path="foobar"),
])
with self.assertRaises(ValueError):
Consumerrouter([
include("channels.tests.test_routing.chatroom_routing_noprefix", path="^/foobar/"),
])

View File

@ -44,7 +44,7 @@ class Worker(object):
"""
if self.signal_handlers:
self.install_signal_handler()
channels = self.channel_layer.registry.all_channel_names()
channels = self.channel_layer.router.channels
while not self.termed:
self.in_job = False
channel, content = self.channel_layer.receive_many(channels, block=True)
@ -66,11 +66,16 @@ class Worker(object):
if content.get("__retries__", 0) == self.message_retries:
message.__doomed__ = True
# Handle the message
consumer = self.channel_layer.registry.consumer_for_channel(channel)
match = self.channel_layer.router.match(message)
if match is None:
logger.exception("Could not find match for message on %s! Check your routing.", channel)
continue
else:
consumer, kwargs = match
if self.callback:
self.callback(channel, message)
try:
consumer(message)
consumer(message, **kwargs)
except ConsumeLater:
# They want to not handle it yet. Re-inject it with a number-of-tries marker.
content['__retries__'] = content.get("__retries__", 0) + 1

View File

@ -62,9 +62,10 @@ Here's what that looks like::
}
# In routing.py
channel_routing = {
"http.request": "myproject.myapp.consumers.http_consumer",
}
from channels.routing import route
channel_routing = [
route("http.request", "myproject.myapp.consumers.http_consumer"),
]
.. warning::
This example, and most of the examples here, use the "in memory" channel
@ -76,7 +77,7 @@ Here's what that looks like::
As you can see, this is a little like Django's ``DATABASES`` setting; there are
named channel layers, with a default one called ``default``. Each layer
needs a channel layer class, some options (if the channel layer needs them),
and a routing scheme, which points to a dict containing the routing settings.
and a routing scheme, which points to a list containing the routing settings.
It's recommended you call this ``routing.py`` and put it alongside ``urls.py``
in your project, but you can put it wherever you like, as long as the path is
correct.
@ -111,11 +112,12 @@ for ``http.request`` - and make this WebSocket consumer instead::
Hook it up to the ``websocket.receive`` channel like this::
# In routing.py
from channels.routing import route
from myproject.myapp.consumers import ws_message
channel_routing = {
"websocket.receive": ws_message,
}
channel_routing = [
route("websocket.receive", ws_message),
]
Now, let's look at what this is doing. It's tied to the
``websocket.receive`` channel, which means that it'll get a message
@ -210,12 +212,13 @@ get the message. Here's all the code::
And what our routing should look like in ``routing.py``::
from channels.routing import route
from myproject.myapp.consumers import ws_add, ws_message, ws_disconnect
channel_routing = {
"websocket.connect": ws_add,
"websocket.receive": ws_message,
"websocket.disconnect": ws_disconnect,
channel_routing = [
route("websocket.connect", ws_add),
route("websocket.receive", ws_message),
route("websocket.disconnect", ws_disconnect),
}
With all that code, you now have a working set of a logic for a chat server.
@ -366,6 +369,7 @@ If you play around with it from the console (or start building a simple
JavaScript chat client that appends received messages to a div), you'll see
that you can set a chat room with the initial request.
Authentication
--------------
@ -430,8 +434,6 @@ chat to people with the same first letter of their username::
# Connected to websocket.connect
@channel_session_user_from_http
def ws_add(message):
# Copy user from HTTP to channel session
transfer_user(message.http_session, message.channel_session)
# Add them to the right group
Group("chat-%s" % message.user.username[0]).add(message.reply_channel)
@ -458,6 +460,58 @@ responses can set cookies, it needs a backend it can write to to separately
store state.
Routing
-------
Channels' ``routing.py`` acts very much like Django's ``urls.py``, including the
ability to route things to different consumers based on ``path``, or any other
message attribute that's a string (for example, ``http.request`` messages have
a ``method`` key you could route based on).
Much like urls, you route using regular expressions; the main difference is that
because the ``path`` is not special-cased - Channels doesn't know that it's a URL -
you have to start patterns with the root ``/``, and end includes without a ``/``
so that when the patterns combine, they work correctly.
Finally, because you're matching against message contents using keyword arguments,
you can only use named groups in your regular expressions! Here's an example of
routing our chat from above::
http_routing = [
route("http.request", poll_consumer, path=r"^/poll/$", method=r"^POST$"),
]
chat_routing = [
route("websocket.connect", chat_connect, path=r"^/(?P<room>[a-zA-Z0-9_]+)/$),
route("websocket.disconnect", chat_disconnect),
]
routing = [
# You can use a string import path as the first argument as well.
include(chat_routing, path=r"^/chat"),
include(http_routing),
]
When Channels loads this routing, it appends any match keys together, so the
``path`` match becomes ``^/chat/(?P<room>[a-zA-Z0-9_]+)/$``. If the include match
or the route match doesn't have the ``^`` character, it will refuse to append them
and error (you can still have matches without ``^`` in either, you just can't
ask Channels to combine them).
Because these matches come through as keyword arguments, we could modify our
consumer above to use a room based on URL rather than username::
# Connected to websocket.connect
@channel_session_user_from_http
def ws_add(message, room):
# Add them to the right group
Group("chat-%s" % room).add(message.reply_channel)
In the next section, we'll change to sending the ``room`` as a part of the
WebSocket message - which you might do if you had a multiplexing client -
but you could use routing there as well.
Models
------

View File

@ -117,19 +117,15 @@ They have the following attributes:
* ``alias``: The alias of this layer.
* ``registry``: An object which represents the layer's mapping of channels
* ``router``: An object which represents the layer's mapping of channels
to consumers. Has the following attributes:
* ``add_consumer(consumer, channels)``: Registers a :ref:`consumer <ref-consumers>`
to handle all channels passed in. ``channels`` should be an iterable of
unicode string names.
* ``channels``: The set of channels this router can handle, as unicode strings
* ``consumer_for_channel(channel)``: Takes a unicode channel name and returns
either a :ref:`consumer <ref-consumers>`, or None, if no consumer is registered.
* ``all_channel_names()``: Returns a list of all channel names this layer has
routed to a consumer. Used by the worker threads to work out what channels
to listen on.
* ``match(message)``: Takes a :ref:`Message <ref-message>` and returns either
a (consumer, kwargs) tuple specifying the consumer to run and the keyword
argument to pass that were extracted via routing patterns, or None,
meaning there's no route available.
.. _ref-asgirequest: